Commit 73f47798 authored by dongzhihong

Merge remote-tracking branch 'origin/develop' into backward2

@@ -4,6 +4,7 @@
 #include <functional>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 #include "paddle/framework/enforce.h"
@@ -41,6 +42,35 @@ class DefaultValueSetter {
   T default_value_;
 };

+template <typename T>
+class EnumInContainer {
+ public:
+  explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
+  void operator()(T& val) const {
+    PADDLE_ENFORCE(container_.find(val) != container_.end(),
+                   "Value %s is not in enum container %s", val,
+                   ContainerDebugString());
+  }
+
+ private:
+  std::string ContainerDebugString() const {
+    std::ostringstream sout;
+    sout << "[";
+    size_t cnt = 0;
+    for (auto& v : container_) {
+      sout << v;
+      ++cnt;
+      if (cnt != container_.size()) {
+        sout << " ,";
+      }
+    }
+    sout << "]";
+    return sout.str();
+  }
+
+  std::unordered_set<T> container_;
+};
+
 // check whether a certain attribute fit its limits
 // an attribute can have more than one limits
 template <typename T>
@@ -50,6 +80,11 @@ class TypedAttrChecker {
  public:
   TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}

+  TypedAttrChecker& InEnum(const std::unordered_set<T>& range) {
+    value_checkers_.push_back(EnumInContainer<T>(range));
+    return *this;
+  }
+
   TypedAttrChecker& LargerThan(const T& lower_bound) {
     value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
     return *this;
......
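The InEnum check above runs when an operator is constructed, so an out-of-range enum value fails fast instead of surfacing at run time. A minimal sketch from the Python side, assuming the op creation helpers introduced later in this commit; the exact exception type raised through the binding is an assumption:

import paddle.v2.framework.create_op_creation_methods as creation

# "tanh" is outside the declared enum {"sigmoid", "softmax"}, so the
# TypedAttrChecker added above should reject the op at creation time.
try:
    op = creation.op_creations.fc(X="X", W="W", Y="Y", activation="tanh")
except Exception as e:
    print("rejected by InEnum:", e)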
@@ -33,7 +33,9 @@ std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
   return grad_ops;
 }

-void PlainNet::CompleteAddOp() {
+void PlainNet::CompleteAddOp(bool calc) {
+  add_op_done_ = true;
+  if (!calc) return;
   std::unordered_set<std::string> input_set;
   std::unordered_set<std::string> output_set;
   std::unordered_set<std::string> temp_output;
@@ -66,7 +68,6 @@ void PlainNet::CompleteAddOp() {
   }

   attrs_["temporary_index"] = tmp_index;
-  add_op_done_ = true;
 }

 std::string PlainNet::DebugString() const {
......
@@ -16,7 +16,6 @@ limitations under the License. */
 #include <paddle/framework/op_desc.pb.h>
 #include <paddle/framework/operator.h>
-#include "paddle/framework/net_proto.pb.h"
 #include "paddle/framework/op_proto.pb.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/scope.h"
@@ -41,7 +40,7 @@ namespace framework {
 class Net : public OperatorBase {
  public:
   virtual void AddOp(const OperatorPtr& op) = 0;
-  virtual void CompleteAddOp() = 0;
+  virtual void CompleteAddOp(bool calc) = 0;
 };
 using NetPtr = std::shared_ptr<Net>;
@@ -86,7 +85,7 @@ class PlainNet : public Net {
     ops_.push_back(op);
   }

-  void CompleteAddOp() override;
+  void CompleteAddOp(bool calculate = true) override;

   std::string DebugString() const override;
......
@@ -71,14 +71,14 @@ class Tensor {
       holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
           boost::get<platform::CPUPlace>(place), product(dims_) * sizeof(T)));
     } else if (platform::is_gpu_place(place)) {
-#ifdef __CUDACC__
+#ifdef PADDLE_ONLY_CPU
+      PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
+#else
       holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
           boost::get<platform::GPUPlace>(place), product(dims_) * sizeof(T)));
-#else
-      PADDLE_ENFORCE(true, "'GPUPlace' is not supported in CPU only device.");
 #endif
     } else {
-      PADDLE_ENFORCE(true, "Unknown 'place'.");
+      PADDLE_THROW("Unknown 'place'.");
     }
     offset_ = 0;
   }
......
@@ -359,12 +359,11 @@ void Layer::backwardActivation() {
   /* Do error clipping */
   if (config_.error_clipping_threshold() > 0.0f) {
     if (FLAGS_log_error_clipping) {
-      CpuVector outGradVec(0, nullptr);
-      outGradVec.subVecFrom(
-          output_.grad->getData(), 0, output_.grad->getElementCnt());
-      real maxAbsGrad = outGradVec.getAbsMax();
+      VectorPtr outGradVec = Vector::create(
+          output_.grad->getData(), output_.grad->getElementCnt(), useGpu_);
+      real maxAbsGrad = outGradVec->getAbsMax();
       if (maxAbsGrad > config_.error_clipping_threshold()) {
-        real avgAbsGrad = outGradVec.getAbsSum() / outGradVec.getSize();
+        real avgAbsGrad = outGradVec->getAbsSum() / outGradVec->getSize();
         LOG(INFO) << " layer=" << config_.name() << " need clipping,"
                   << " max error=" << maxAbsGrad << " avg error=" << avgAbsGrad;
       }
......
@@ -27,7 +27,8 @@ function(op_library TARGET)
   endif()

   list(LENGTH cu_srcs cu_srcs_len)
-  if (${cu_srcs_len} EQUAL 0)
+  list(LENGTH op_library_DEPS dep_len)
+  if (${cu_srcs_len} EQUAL 0 AND ${dep_len} EQUAL 0)
     message(WARNING "The op library ${TARGET} not support GPU!")
   endif()
@@ -47,3 +48,6 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu)
 op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
 op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc)
 op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
+op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
+        softmax_op net)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {

class FullyConnectedOp : public framework::PlainNet {
 public:
  void Init() override {
    AddOp(framework::OpRegistry::CreateOp("mul",
                                          {
                                              Input("X"), Input("W"),
                                          },
                                          {Output("before_act")},
                                          {}));
    auto b = Input("b");
    if (b != framework::OperatorBase::EMPTY_VAR_NAME()) {
      AddOp(framework::OpRegistry::CreateOp("rowwise_add",
                                            {Output("before_act"), Input("b")},
                                            {Output("before_act")},
                                            {}));
    }

    auto activation = GetAttr<std::string>("activation");
    AddOp(framework::OpRegistry::CreateOp(
        activation, {Output("before_act")}, {Output("Y")}, {}));
    CompleteAddOp(false);
  }
};

class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  FullyConnectedOpMaker(framework::OpProto *proto,
                        framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "the input of fc operator");
    AddInput("W", "the weight of fc operator");
    AddInput("b", "the bias of fc operator");

    AddOutput("Y", "the output of fc operator");
    AddOutput(
        "before_act", "the before activation output of fc operator", true);
    AddAttr<std::string>("activation", "The activation key for fc layer")
        .SetDefault("sigmoid")
        .InEnum({"sigmoid", "softmax"});

    //! TODO(yuyang18): Complete comment;
    AddComment("FullyConnected Operator");
  }
};
}  // namespace operators
}  // namespace paddle

USE_OP(mul);
USE_OP(rowwise_add);
USE_OP(sigmoid);
USE_OP(softmax);
REGISTER_OP(fc,
            paddle::operators::FullyConnectedOp,
            paddle::operators::FullyConnectedOpMaker);
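The three sub-ops composed here compute Y = activation(X·W + b), with the rowwise_add skipped when "b" is bound to the empty variable name. For reference, the same math as a hedged plain-numpy sketch (not Paddle code, just the computation the composed net performs, assuming 2-D inputs):

import numpy as np

def fc_forward(X, W, b=None, activation="sigmoid"):
    before_act = X.dot(W)                # the "mul" op
    if b is not None:
        before_act = before_act + b      # the "rowwise_add" op
    if activation == "sigmoid":
        return 1.0 / (1.0 + np.exp(-before_act))
    if activation == "softmax":
        e = np.exp(before_act - before_act.max(axis=1, keepdims=True))
        return e / e.sum(axis=1, keepdims=True)
    raise ValueError("activation must be 'sigmoid' or 'softmax'")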
 cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
-        add_op mul_op rowwise_add_op sigmoid_op softmax_op)
+        add_op fc_op)
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <Python.h>
 #include <paddle/framework/op_registry.h>
+#include <paddle/framework/operator.h>
 #include <paddle/framework/scope.h>
 #include <paddle/pybind/tensor_bind.h>
 #include <pybind11/numpy.h>
@@ -26,10 +27,7 @@ namespace py = pybind11;
 namespace pd = paddle::framework;

 USE_OP(add_two);
-USE_OP(softmax);
-USE_OP(mul);
-USE_OP(rowwise_add);
-USE_OP(sigmoid);
+USE_OP_WITHOUT_KERNEL(fc);

 PYBIND11_PLUGIN(core) {
   py::module m("core", "C++ core of Paddle Paddle");
@@ -53,7 +51,9 @@ PYBIND11_PLUGIN(core) {
         self.mutable_data<int>(paddle::platform::CPUPlace());
       })
       .def("set", paddle::pybind::PyTensorSetFromArray<float>)
-      .def("set", paddle::pybind::PyTensorSetFromArray<int>);
+      .def("set", paddle::pybind::PyTensorSetFromArray<int>)
+      .def("shape",
+           [](pd::Tensor& self) { return pd::vectorize(self.dims()); });

   py::class_<pd::Variable>(m, "Variable", R"DOC(Variable Class.
@@ -83,15 +83,16 @@ All parameter, weight, gradient are variables in Paddle.
   //! @note: Be careful! PyBind will return std::string as an unicode, not
   //! Python str. If you want a str object, you should cast them in Python.
-  m.def("get_all_op_protos", []() -> std::vector<std::string> {
+  m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
     auto& protos = pd::OpRegistry::protos();
-    std::vector<std::string> ret_values;
+    std::vector<py::bytes> ret_values;
     for (auto it = protos.begin(); it != protos.end(); ++it) {
       PADDLE_ENFORCE(it->second.IsInitialized(),
                      "OpProto must all be initialized");
-      ret_values.emplace_back();
-      PADDLE_ENFORCE(it->second.SerializeToString(&ret_values.back()),
+      std::string str;
+      PADDLE_ENFORCE(it->second.SerializeToString(&str),
                      "Serialize OpProto Error. This could be a bug of Paddle.");
+      ret_values.push_back(py::bytes(str));
     }
     return ret_values;
   });
@@ -101,17 +102,26 @@ All parameter, weight, gradient are variables in Paddle.
       .def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
       .def("temp", pd::OperatorBase::TMP_VAR_NAME);

+  py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
+      .def_static("cpu_context", []() -> paddle::platform::DeviceContext* {
+        return new paddle::platform::CPUDeviceContext();
+      });
+
   py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
       .def("__str__", &pd::OperatorBase::DebugString)
-      .def_static("create", [](const std::string& protobin) {
-        pd::OpDesc desc;
-        PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
-                       "Cannot parse user input to OpDesc");
-        PADDLE_ENFORCE(desc.IsInitialized(),
-                       "User OpDesc is not initialized, reason %s",
-                       desc.InitializationErrorString());
-        return pd::OpRegistry::CreateOp(desc);
-      });
+      .def_static("create",
+                  [](py::bytes protobin) {
+                    pd::OpDesc desc;
+                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                                   "Cannot parse user input to OpDesc");
+                    PADDLE_ENFORCE(desc.IsInitialized(),
+                                   "User OpDesc is not initialized, reason %s",
+                                   desc.InitializationErrorString());
+                    return pd::OpRegistry::CreateOp(desc);
+                  })
+      .def("infer_shape", &pd::OperatorBase::InferShape)
+      .def("run", &pd::OperatorBase::Run)
+      .def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; });

   return m.ptr();
 }
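Returning py::bytes matters here because pybind11 maps std::string to a Python unicode str, and arbitrary serialized protobuf bytes are not valid unicode. A hedged sketch of the consuming side; the op_proto_pb2 module path is an assumption for illustration:

import paddle.v2.framework.core as core
# Hypothetical generated-protobuf module path, for illustration only.
from paddle.v2.framework.proto import op_proto_pb2

for proto_bin in core.get_all_op_protos():
    proto = op_proto_pb2.OpProto()
    proto.ParseFromString(proto_bin)  # raw bytes parse cleanly; a unicode str would not
    print(proto.type)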
 #!/bin/bash
 function abort(){
     echo "Your change doesn't follow PaddlePaddle's code style." 1>&2
-    echo "Please use pre-commit to reformat your code and git push again." 1>&2
+    echo "Please use pre-commit to check what is wrong." 1>&2
     exit 1
 }
@@ -19,7 +19,8 @@ ln -sf $TRAVIS_BUILD_DIR $GOPATH/src/github.com/PaddlePaddle/Paddle
 cd $GOPATH/src/github.com/PaddlePaddle/Paddle/go; glide install; cd -

 if ! pre-commit run -a ; then
-    git diff --exit-code
+    git diff
+    exit 1
 fi

 trap : 0
@@ -1575,7 +1575,13 @@ class MultiClassCrossEntropySelfNormCostLayer(LayerBase):

 @config_layer('fc')
 class FCLayer(LayerBase):
-    def __init__(self, name, size, inputs, bias=True, **xargs):
+    def __init__(self,
+                 name,
+                 size,
+                 inputs,
+                 bias=True,
+                 error_clipping_threshold=None,
+                 **xargs):
         super(FCLayer, self).__init__(name, 'fc', size, inputs=inputs, **xargs)
         for input_index in xrange(len(self.inputs)):
             input_layer = self.get_input_layer(input_index)
@@ -1592,6 +1598,8 @@ class FCLayer(LayerBase):
             self.create_input_parameter(input_index, psize, dims, sparse,
                                         format)
         self.create_bias_parameter(bias, self.config.size)
+        if error_clipping_threshold is not None:
+            self.config.error_clipping_threshold = error_clipping_threshold

 @config_layer('selective_fc')
......
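With this change a trainer config can override error clipping per layer rather than only through the global setting. A hedged example of the v1 config_parser usage; the exact Layer/Input call shape is an assumption:

Layer(
    name="fc1",
    type="fc",
    size=128,
    inputs=[Input("data")],
    bias=True,
    error_clipping_threshold=10.0,  # new per-layer override added above
)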
@@ -26,8 +26,9 @@ import sentiment
 import wmt14
 import mq2007
 import flowers
+import voc2012

 __all__ = [
     'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment'
-    'uci_housing', 'wmt14', 'mq2007', 'flowers'
+    'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012'
 ]
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2.dataset.voc2012
import unittest
class TestVOC(unittest.TestCase):
    def check_reader(self, reader):
        sum = 0
        label = 0
        for l in reader():
            self.assertEqual(l[0].size, 3 * l[1].size)
            sum += 1
        return sum

    def test_train(self):
        count = self.check_reader(paddle.v2.dataset.voc2012.train())
        self.assertEqual(count, 2913)

    def test_test(self):
        count = self.check_reader(paddle.v2.dataset.voc2012.test())
        self.assertEqual(count, 1464)

    def test_val(self):
        count = self.check_reader(paddle.v2.dataset.voc2012.val())
        self.assertEqual(count, 1449)


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image dataset for segmentation.
The 2012 dataset contains images from 2008-2011 for which additional
segmentations have been prepared. As in previous years the assignment
to training/test sets has been maintained. The total number of images
with segmentation has been increased from 7,062 to 9,993.
"""
import tarfile
import io
import numpy as np
from paddle.v2.dataset.common import download
from paddle.v2.image import *
from PIL import Image
__all__ = ['train', 'test', 'val']
VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/\
VOCtrainval_11-May-2012.tar'
VOC_MD5 = '6cd6e144f989b92b3379bac3b3de84fd'
SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt'
DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg'
LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png'
CACHE_DIR = 'voc2012'
def reader_creator(filename, sub_name):
    tarobject = tarfile.open(filename)
    name2mem = {}
    for ele in tarobject.getmembers():
        name2mem[ele.name] = ele

    def reader():
        set_file = SET_FILE.format(sub_name)
        sets = tarobject.extractfile(name2mem[set_file])
        for line in sets:
            line = line.strip()
            data_file = DATA_FILE.format(line)
            label_file = LABEL_FILE.format(line)
            data = tarobject.extractfile(name2mem[data_file]).read()
            label = tarobject.extractfile(name2mem[label_file]).read()
            data = Image.open(io.BytesIO(data))
            label = Image.open(io.BytesIO(label))
            data = np.array(data)
            label = np.array(label)
            yield data, label

    return reader


def train():
    """
    Create a train dataset reader containing 2913 images in HWC order.
    """
    return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'trainval')


def test():
    """
    Create a test dataset reader containing 1464 images in HWC order.
    """
    return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'train')


def val():
    """
    Create a val dataset reader containing 1449 images in HWC order.
    """
    return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'val')
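A minimal usage sketch of the new reader: each sample is a (data, label) pair of numpy arrays, the image in HWC order and the label a per-pixel class map; the tarball is downloaded and cached on first use:

import paddle.v2.dataset.voc2012 as voc2012

reader = voc2012.val()  # a callable reader over the 'val' split
for data, label in reader():
    print(data.shape, label.shape)  # e.g. (H, W, 3) and (H, W)
    break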
 add_python_test(test_framework test_protobuf.py test_scope.py
     test_default_scope_funcs.py test_op_creation_methods.py
-    test_tensor.py)
+    test_tensor.py test_fc_op.py)
import paddle.v2.framework.core as core
import unittest
import numpy
import paddle.v2.framework.create_op_creation_methods as creation
class TestFc(unittest.TestCase):
    def test_fc(self):
        scope = core.Scope(None)
        x = scope.create_var("X")
        x_tensor = x.get_tensor()
        x_tensor.set_dims([1000, 784])
        x_tensor.alloc_float()

        w = scope.create_var("W")
        w_tensor = w.get_tensor()
        w_tensor.set_dims([784, 100])
        w_tensor.alloc_float()

        w_tensor.set(numpy.random.random((784, 100)).astype("float32"))

        # Set a real numpy array here.
        # x_tensor.set(numpy.array([]))

        op = creation.op_creations.fc(X="X", Y="Y", W="W")

        for out in op.outputs():
            if scope.get_var(out) is None:
                scope.create_var(out).get_tensor()

        tensor = scope.get_var("Y").get_tensor()
        op.infer_shape(scope)
        self.assertEqual([1000, 100], tensor.shape())

        ctx = core.DeviceContext.cpu_context()
        op.run(scope, ctx)

        # After complete all ops, check Y is expect or not.


if __name__ == '__main__':
    unittest.main()
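The trailing comment leaves the check on Y as a TODO. A hedged sketch of what it could look like, assuming X is set to a real array the way W is above, and assuming some tensor read-back accessor, which this diff does not provide:

# Hypothetical: x_array/w_array are the numpy arrays set on X and W;
# tensor_to_numpy stands in for a read-back accessor not in this diff.
expected = 1.0 / (1.0 + numpy.exp(-x_array.dot(w_array)))  # default sigmoid
numpy.testing.assert_allclose(tensor_to_numpy(tensor), expected, rtol=1e-5)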
@@ -20,6 +20,7 @@ setup_requires=["requests",
                 "matplotlib",
                 "rarfile",
                 "scipy>=0.19.0",
+                "Pillow",
                 "nltk"]

 if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']:
......