Unverified commit c9b57dcc authored by Yu Yang, committed by GitHub

ReadFromArray/WriteToArray op (#5407)

* Use stable_sort in lod_rank_table

It is easier to debug and test when using `stable_sort`, and the time
complexity is unchanged.

* Add LoDTensorArray

* Stash

* Better debug message for IsInitialized

* Stash

* Better debug message for IsInitialized

* Complete array read/write op unittests
Parent f78731c8
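
Why stable_sort: a stable sort keeps the original relative order of sequences whose lengths tie, so the resulting rank table is fully deterministic and easy to assert against in tests, while the O(N log N) complexity matches std::sort. A minimal Python sketch of the idea (Python's sorted is itself stable, even with reverse=True):

# Items are (original_index, sequence_length); lod_rank_table sorts by
# length in descending order. Stability means equal lengths keep their
# original index order, so the output below is fully deterministic.
seqs = [(0, 3), (1, 5), (2, 3), (3, 5)]
ranked = sorted(seqs, key=lambda item: item[1], reverse=True)
print(ranked)  # [(1, 5), (3, 5), (0, 3), (2, 3)]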
......@@ -349,6 +349,9 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
info.infer_var_type_(*this, block);
} else {
// all output type is LoDTensor by default
VLOG(10) << this->Type()
<< " has not registered InferVarType. Set output variables to "
"LOD_TENSOR";
for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) {
block->Var(out_var_name)->SetType(VarDesc::LOD_TENSOR);
......
......@@ -110,7 +110,7 @@ function(op_library TARGET)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n")
endif()
# reduce_op contains several operators
if ("${TARGET}" STREQUAL "reduce_op")
set(pybind_flag 1)
......@@ -118,6 +118,11 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
endif()
if ("${TARGET}" STREQUAL "tensor_array_read_write_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(write_to_array);\n")
endif()
# pybind USE_NO_KERNEL_OP
# HACK: if REGISTER_OP_CPU_KERNEL is present, the operator must have a kernel
file(READ ${TARGET}.cc TARGET_CONTENT)
......@@ -161,6 +166,7 @@ set(DEPS_OPS
sequence_pool_op
lod_rank_table_op
lstm_op
tensor_array_read_write_op
gru_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
......@@ -171,6 +177,7 @@ op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
op_library(tensor_array_read_write_op SRCS tensor_array_read_write_op.cc)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
endif()
......
......@@ -35,7 +35,9 @@ class FillConstantOp : public framework::OperatorWithKernel {
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
int data_type = ctx.Attr<int>("data_type");
VLOG(10) << " FillConstant data_type = " << data_type;
return static_cast<framework::DataType>(data_type);
}
};
......@@ -71,4 +73,5 @@ REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp,
REGISTER_OP_CPU_KERNEL(
fill_constant, ops::FillConstantOpKernel<paddle::platform::CPUPlace, float>,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, double>,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, int>);
ops::FillConstantOpKernel<paddle::platform::CPUPlace, int>,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, int64_t>);
......@@ -20,4 +20,5 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
fill_constant, ops::FillConstantOpKernel<paddle::platform::GPUPlace, float>,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, double>,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, int>);
ops::FillConstantOpKernel<paddle::platform::GPUPlace, int>,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, int64_t>);
......@@ -31,7 +31,6 @@ class IncrementOp : public framework::OperatorWithKernel {
}
};
template <typename AttrType>
class IncrementOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IncrementOpMaker(framework::OpProto *proto,
......@@ -39,10 +38,10 @@ class IncrementOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input tensor of increment operator");
AddOutput("Out", "(Tensor) The output tensor of increment operator.");
AddAttr<AttrType>("step",
"(float, default 1.0) "
"The step size by which the "
"input tensor will be incremented.")
AddAttr<float>("step",
"(float, default 1.0) "
"The step size by which the "
"input tensor will be incremented.")
.SetDefault(1.0);
AddComment(R"DOC(
Increment Operator.
......@@ -73,7 +72,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker<float>,
REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker,
ops::IncrementGradOpMaker);
REGISTER_OP_CPU_KERNEL(increment,
ops::IncrementKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
increment, ops::IncrementKernel<paddle::platform::CPUPlace, float>,
ops::IncrementKernel<paddle::platform::CPUPlace, double>,
ops::IncrementKernel<paddle::platform::CPUPlace, int>,
ops::IncrementKernel<paddle::platform::CPUPlace, int64_t>);
......@@ -16,4 +16,7 @@
REGISTER_OP_GPU_KERNEL(
increment,
paddle::operators::IncrementKernel<paddle::platform::GPUPlace, float>);
paddle::operators::IncrementKernel<paddle::platform::GPUPlace, float>,
paddle::operators::IncrementKernel<paddle::platform::GPUPlace, double>,
paddle::operators::IncrementKernel<paddle::platform::GPUPlace, int>,
paddle::operators::IncrementKernel<paddle::platform::GPUPlace, int64_t>);
......@@ -19,7 +19,7 @@
namespace paddle {
namespace operators {
template <typename Place, typename T, typename AttrType = T>
template <typename Place, typename T>
class IncrementKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& context) const {
......@@ -27,7 +27,7 @@ class IncrementKernel : public framework::OpKernel<T> {
auto* in = context.Input<framework::Tensor>("X");
tensor->mutable_data<T>(in->place());
auto step = static_cast<T>(context.Attr<AttrType>("step"));
auto step = static_cast<T>(context.Attr<float>("step"));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
......
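The int64_t kernels added to fill_constant and increment above are what let Python build and step an integer subscript tensor for the array ops that follow. A hedged usage sketch, mirroring the unit test later in this commit:

import paddle.v2.framework.layers as layers

# zeros() lowers to fill_constant and increment() to the increment op;
# both now have int64 kernels, so an int64 loop counter works end to end.
i = layers.zeros(shape=[1], dtype='int64')
layers.increment(x=i)  # i becomes [1]; the default step is 1.0
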
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class ArrayOpBase : public framework::OperatorBase {
public:
ArrayOpBase(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
protected:
size_t GetOffset(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const {
auto *i = scope.FindVar(Input("I"));
PADDLE_ENFORCE(i != nullptr, "I must be set");
auto &i_tensor = i->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
size_t offset;
if (platform::is_gpu_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU
framework::Tensor t;
t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
dev_ctx.Wait();
offset = static_cast<size_t>(*t.data<int64_t>());
} else {
offset = static_cast<size_t>(*i_tensor.data<int64_t>());
}
return offset;
}
};
class WriteToArrayOp : public ArrayOpBase {
public:
WriteToArrayOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOpBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *x = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x != nullptr, "X must be set");
auto &x_tensor = x->Get<framework::LoDTensor>();
size_t offset = GetOffset(scope, dev_ctx);
auto *out =
scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensorArray>();
if (offset >= out->size()) {
out->resize(offset + 1);
}
auto *out_tensor = &out->at(offset);
out_tensor->CopyFrom(x_tensor, dev_ctx.GetPlace(), dev_ctx);
out_tensor->set_lod(x_tensor.lod());
}
};
class WriteToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
WriteToArrayOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) the tensor will be written to tensor array");
AddInput(
"I",
"(Tensor) the subscript index in tensor array. The number of element "
"should be 1");
AddOutput("Out", "(TensorArray) the tensor array will be written");
AddComment(R"DOC(Write a LoDTensor to a LoDTensor array.
Assume T is LoDTensor, i is the subscript of the array, and A is the array. The
equation is
A[i] = T
)DOC");
}
};
class WriteToArrayInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("I"), "Must set the subscript index");
PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1,
"The number of element of subscript index must be 1");
PADDLE_ENFORCE(context->HasInput("X"), NotHasXError());
PADDLE_ENFORCE(context->HasOutput("Out"), NotHasOutError());
context->SetOutputDim("Out", context->GetInputDim("X"));
}
protected:
virtual const char *NotHasXError() const { return "Must set the lod tensor"; }
virtual const char *NotHasOutError() const {
return "Must set the lod tensor array";
}
};
class WriteToArrayInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override {
VLOG(10) << "I am here?";
for (auto &out_var : op_desc.OutputArgumentNames()) {
VLOG(10) << "Set Variable " << out_var << " as LOD_TENSOR_ARRAY";
block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
}
}
};
class ReadFromArrayOp : public ArrayOpBase {
public:
ReadFromArrayOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ArrayOpBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *x = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x != nullptr, "X must be set");
auto &x_array = x->Get<framework::LoDTensorArray>();
auto *out = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out != nullptr, "Out must be set");
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
size_t offset = GetOffset(scope, dev_ctx);
PADDLE_ENFORCE_LT(offset, x_array.size());
out_tensor->CopyFrom(x_array[offset], dev_ctx.GetPlace(), dev_ctx);
out_tensor->set_lod(x_array[offset].lod());
}
};
class ReadFromArrayProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ReadFromArrayProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(TensorArray) the array will be read from.");
AddInput("I",
"(Tensor) the subscript index in tensor array. The number of "
"element should be 1");
AddOutput("Out", "(LoDTensor) the tensor will be read from.");
AddComment(R"DOC(Read a LoDTensor from a LoDTensor Array
Assume T is LoDTensor, i is the subscript of the array, and A is the array. The
equation is
T = A[i]
)DOC");
}
};
class ReadFromArrayInferShape : public WriteToArrayInferShape {
protected:
const char *NotHasXError() const override {
return "The input array X must be set";
}
const char *NotHasOutError() const override {
return "The output tensor out must be set";
}
};
class WriteToArrayGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *grad_op = new framework::OpDescBind();
grad_op->SetType("read_from_array");
grad_op->SetInput("I", Input("I"));
grad_op->SetInput("X", OutputGrad("Out"));
grad_op->SetOutput("Out", InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
class ReadFromArrayGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *grad_op = new framework::OpDescBind();
grad_op->SetType("write_to_array");
grad_op->SetInput("I", Input("I"));
grad_op->SetInput("X", OutputGrad("Out"));
grad_op->SetOutput("Out", InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(write_to_array, ops::WriteToArrayOp,
ops::WriteToArrayInferShape, ops::WriteToArrayOpProtoMaker,
ops::WriteToArrayGradMaker, ops::WriteToArrayInferVarType);
REGISTER_OPERATOR(read_from_array, ops::ReadFromArrayOp,
ops::ReadFromArrayInferShape, ops::ReadFromArrayProtoMaker,
ops::ReadFromArrayGradMaker);
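
Taken together, the two ops implement A[i] = T and T = A[i] on a LoDTensorArray, and their grad makers register each op as the gradient of the other. A hedged pure-Python model of the semantics (plain lists stand in for LoDTensorArray; the real ops copy tensor data and LoD):

def write_to_array(array, i, tensor):
    # Mirrors WriteToArrayOp::Run: grow the array on demand so A[i] = T holds.
    if i >= len(array):
        array.extend([None] * (i + 1 - len(array)))
    array[i] = tensor
    return array

def read_from_array(array, i):
    # Mirrors ReadFromArrayOp::Run: enforce i < len(A), then T = A[i].
    assert i < len(array), "subscript must be less than the array size"
    return array[i]

arr = write_to_array([], 2, 't2')  # array auto-resized to length 3
assert read_from_array(arr, 2) == 't2'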
......@@ -97,6 +97,15 @@ namespace pybind {
using namespace paddle::framework; // NOLINT
template <typename T>
static py::bytes SerializeMessage(T &self) {
// Check IsInitialized in Python
std::string retv;
PADDLE_ENFORCE(self.Proto()->SerializePartialToString(&retv),
"Cannot serialize message");
return retv;
}
// Bind Methods
void BindProgramDesc(py::module &m) {
py::class_<ProgramDescBind>(m, "ProgramDesc", "")
......@@ -132,17 +141,7 @@ void BindProgramDesc(py::module &m) {
.def("block", &ProgramDescBind::MutableBlock,
py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size)
.def("serialize_to_string",
[](ProgramDescBind &program_desc) -> py::bytes {
const ProgramDesc *desc = program_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"ProgramDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize ProgramDesc Error. This could be a bug of Paddle.");
return res;
})
.def("serialize_to_string", SerializeMessage<ProgramDescBind>)
.def("parse_from_string",
[](ProgramDescBind &program_desc, const std::string &data) {
ProgramDesc *desc = program_desc.Proto();
......@@ -181,16 +180,7 @@ void BindBlockDesc(py::module &m) {
py::return_value_policy::reference)
.def("op_size", &BlockDescBind::OpSize)
.def("op", &BlockDescBind::Op, py::return_value_policy::reference)
.def("serialize_to_string", [](BlockDescBind &block_desc) -> py::bytes {
const BlockDesc *desc = block_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"BlockDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize BlockDesc Error. This could be a bug of Paddle.");
return res;
});
.def("serialize_to_string", SerializeMessage<BlockDescBind>);
}
void BindVarDsec(py::module &m) {
......@@ -219,17 +209,7 @@ void BindVarDsec(py::module &m) {
.def("set_lod_level", &VarDescBind::SetLoDLevel)
.def("type", &VarDescBind::GetType)
.def("set_type", &VarDescBind::SetType)
.def("serialize_to_string",
[](VarDescBind &var_desc) -> py::bytes {
const VarDesc *desc = var_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"VarDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize VarDesc Error. This could be a bug of Paddle.");
return res;
})
.def("serialize_to_string", SerializeMessage<VarDescBind>)
.def("persistable", &VarDescBind::Persistable)
.def("set_persistable", &VarDescBind::SetPersistable);
......@@ -274,16 +254,7 @@ void BindOpDesc(py::module &m) {
.def("check_attrs", &OpDescBind::CheckAttrs)
.def("infer_shape", &OpDescBind::InferShape)
.def("infer_var_type", &OpDescBind::InferVarType)
.def("serialize_to_string", [](OpDescBind &op_desc) -> py::bytes {
const OpDesc *desc = op_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"OpDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize OpDesc Error. This could be a bug of Paddle.");
return res;
});
.def("serialize_to_string", SerializeMessage<OpDescBind>);
}
} // namespace pybind
......
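The four duplicated serialize_to_string lambdas collapse into the single SerializeMessage template above, and the C++ side now calls SerializePartialToString, deferring the IsInitialized check to Python (see _debug_string_ in framework.py below). The Python-side round trip is unchanged; a hedged sketch, assuming desc is any bound descriptor such as a VarDesc:

protostr = desc.serialize_to_string()  # SerializePartialToString under the hood
proto = framework_pb2.VarDesc.FromString(str(protostr))  # framework_pb2 as imported in framework.py
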
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Block, Program
from paddle.v2.framework.framework import Block, Program, g_main_program
g_scope = core.Scope()
......@@ -18,7 +18,7 @@ class Executor(object):
self.executor = core.Executor(act_places)
def run(self,
program,
program=None,
feed=None,
fetch_list=None,
feed_var_name='feed',
......@@ -29,6 +29,9 @@ class Executor(object):
if fetch_list is None:
fetch_list = []
if program is None:
program = g_main_program
if not isinstance(program, Program):
raise TypeError("program should be a Program instance")
......
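Executor.run now falls back to the global g_main_program when no program is passed, and rejects anything that is not a Program. A self-contained mimic of just that fallback logic (stand-in classes; illustrative only):

class Program(object):
    pass

g_main_program = Program()

def resolve_program(program=None):
    # Mirrors the new branch in Executor.run.
    if program is None:
        program = g_main_program
    if not isinstance(program, Program):
        raise TypeError("program should be a Program instance")
    return program

assert resolve_program() is g_main_program  # default falls back to the global
# resolve_program(42) would raise TypeError
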
......@@ -12,6 +12,14 @@ def unique_name(prefix):
return "_".join([prefix, str(uid)])
def _debug_string_(proto):
error_fields = list()
if not proto.IsInitialized(error_fields):
raise ValueError("{0} are not initialized\nThe message is {1}".format(
error_fields, proto))
return proto.__str__()
class Variable(object):
def __init__(self,
block,
......@@ -95,7 +103,7 @@ class Variable(object):
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.VarDesc.FromString(str(protostr))
return proto.__str__()
return _debug_string_(proto)
__repr__ = __str__
......@@ -286,7 +294,7 @@ class Operator(object):
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.OpDesc.FromString(str(protostr))
return proto.__str__()
return _debug_string_(proto)
__repr__ = __str__
......@@ -343,7 +351,7 @@ class Block(object):
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.BlockDesc.FromString(str(protostr))
return proto.__str__()
return _debug_string_(proto)
__repr__ = __str__
......@@ -448,7 +456,7 @@ class Program(object):
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.ProgramDesc.FromString(str(protostr))
return proto.__str__()
return _debug_string_(proto)
def clone(self):
p = Program()
......
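_debug_string_ above reports which required proto fields are missing instead of letting serialization fail opaquely, and every __str__ in this file now routes through it. A self-contained mimic (FakeProto and the field name are stand-ins for a real framework_pb2 message):

class FakeProto(object):
    def IsInitialized(self, errors=None):
        # Protobuf's IsInitialized(errors) appends the paths of missing
        # required fields to `errors` and returns False.
        if errors is not None:
            errors.append('t.dims')  # hypothetical missing field
        return False

def _debug_string_(proto):
    error_fields = list()
    if not proto.IsInitialized(error_fields):
        raise ValueError("{0} are not initialized\nThe message is {1}".format(
            error_fields, proto))
    return proto.__str__()

try:
    _debug_string_(FakeProto())
except ValueError as e:
    print(e)  # ['t.dims'] are not initialized ...
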
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, Operator
from paddle.v2.framework.initializer import ConstantInitializer, NormalInitializer
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
Operator
from paddle.v2.framework.initializer import ConstantInitializer, \
NormalInitializer
from paddle.v2.framework.layer_helper import LayerHelper, unique_name
import re
......@@ -751,3 +753,67 @@ def lod_rank_table(x, level=0, main_program=None):
outputs={'Out': table},
attrs={'level': level})
return table
def fill_constant(shape, dtype, value, main_program=None):
helper = LayerHelper("ones", **locals())
out = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={
'shape': shape,
'data_type': out.data_type,
'value': float(value)
})
out.stop_gradient = True
return out
def ones(shape, dtype, main_program=None):
return fill_constant(value=1.0, **locals())
def zeros(shape, dtype, main_program=None):
return fill_constant(value=0.0, **locals())
def increment(x, value=1.0, main_program=None):
helper = LayerHelper("increment", **locals())
helper.append_op(
type='increment',
inputs={'X': [x]},
outputs={'Out': [x]},
attrs={'step': value})
return x
def array_write(x, i, array=None, main_program=None):
helper = LayerHelper('array_write', **locals())
if array is None:
array = helper.create_variable(
name="{0}.out".format(helper.name),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=x.data_type)
helper.append_op(
type='write_to_array',
inputs={'X': [x],
'I': [i]},
outputs={'Out': [array]})
return array
def array_read(array, i, main_program=None):
helper = LayerHelper('array_read', **locals())
if not isinstance(array, Variable) or \
array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError("array should be a tensor array variable")
out = helper.create_tmp_variable(dtype=array.data_type)
helper.append_op(
type='read_from_array',
inputs={'X': [array],
'I': [i]},
outputs={'Out': [out]})
return out
import unittest
import numpy
import paddle.v2.framework.core as core
import paddle.v2.framework.layers as layers
from paddle.v2.framework.executor import Executor
class TestArrayReadWrite(unittest.TestCase):
def test_read_write(self):
x = [
layers.data(name='x0', shape=[100]),
layers.data(name='x1', shape=[100]),
layers.data(name='x2', shape=[100]),
]
for each_x in x:
each_x.stop_gradient = False
i = layers.zeros(shape=[1], dtype='int64')
arr = layers.array_write(x=x[0], i=i)
layers.increment(x=i)
arr = layers.array_write(x=x[1], i=i, array=arr)
layers.increment(x=i)
arr = layers.array_write(x=x[2], i=i, array=arr)
i = layers.zeros(shape=[1], dtype='int64')
a0 = layers.array_read(array=arr, i=i)
layers.increment(x=i)
a1 = layers.array_read(array=arr, i=i)
layers.increment(x=i)
a2 = layers.array_read(array=arr, i=i)
mean_a0 = layers.mean(x=a0)
mean_a1 = layers.mean(x=a1)
mean_a2 = layers.mean(x=a2)
a_sum = layers.sums(input=[mean_a0, mean_a1, mean_a2])
mean_x0 = layers.mean(x=x[0])
mean_x1 = layers.mean(x=x[1])
mean_x2 = layers.mean(x=x[2])
x_sum = layers.sums(input=[mean_x0, mean_x1, mean_x2])
scope = core.Scope()
cpu = core.CPUPlace()
exe = Executor(cpu)
tensor = core.LoDTensor()
tensor.set(numpy.random.random(size=(100, 100)).astype('float32'), cpu)
outs = map(numpy.array,
exe.run(feed={'x0': tensor,
'x1': tensor,
'x2': tensor},
fetch_list=[a_sum, x_sum],
scope=scope))
self.assertEqual(outs[0], outs[1])
if __name__ == '__main__':
unittest.main()
import unittest
from paddle.v2.framework.framework import Program
class TestDebugStringFramework(unittest.TestCase):
def test_debug_str(self):
p = Program()
p.current_block().create_var(name='t', shape=[0, 1])
self.assertRaises(ValueError, callableObj=p.__str__)
if __name__ == '__main__':
unittest.main()