提交 44e1ac38 编写于 作者: P peterzhang2029

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into bi_tensor_prod_op

#!/usr/bin/env python
from paddle.trainer_config_helpers import *
height = 224
width = 224
num_class = 1000
batch_size = get_config_arg('batch_size', int, 64)
layer_num = get_config_arg("layer_num", int, 50)
is_test = get_config_arg("is_test", bool, False)
args = {'height': height, 'width': width, 'color': True, 'num_class': num_class}
define_py_data_sources2(
"train.list", None, module="provider", obj="process", args=args)
settings(
batch_size=batch_size,
learning_rate=0.01 / batch_size,
learning_method=MomentumOptimizer(0.9),
regularization=L2Regularization(0.0005 * batch_size))
#######################Network Configuration #############
def conv_bn_layer(name,
input,
filter_size,
num_filters,
stride,
padding,
channels=None,
active_type=ReluActivation()):
"""
A wrapper for conv layer with batch normalization layers.
Note:
conv layer has no activation.
"""
tmp = img_conv_layer(
name=name + "_conv",
input=input,
filter_size=filter_size,
num_channels=channels,
num_filters=num_filters,
stride=stride,
padding=padding,
act=LinearActivation(),
bias_attr=False)
return batch_norm_layer(
name=name + "_bn", input=tmp, act=active_type, use_global_stats=is_test)
def bottleneck_block(name, input, num_filters1, num_filters2):
"""
A wrapper for bottlenect building block in ResNet.
Last conv_bn_layer has no activation.
Addto layer has activation of relu.
"""
last_name = conv_bn_layer(
name=name + '_branch2a',
input=input,
filter_size=1,
num_filters=num_filters1,
stride=1,
padding=0)
last_name = conv_bn_layer(
name=name + '_branch2b',
input=last_name,
filter_size=3,
num_filters=num_filters1,
stride=1,
padding=1)
last_name = conv_bn_layer(
name=name + '_branch2c',
input=last_name,
filter_size=1,
num_filters=num_filters2,
stride=1,
padding=0,
active_type=LinearActivation())
return addto_layer(
name=name + "_addto", input=[input, last_name], act=ReluActivation())
def mid_projection(name, input, num_filters1, num_filters2, stride=2):
"""
A wrapper for middile projection in ResNet.
projection shortcuts are used for increasing dimensions,
and other shortcuts are identity
branch1: projection shortcuts are used for increasing
dimensions, has no activation.
branch2x: bottleneck building block, shortcuts are identity.
"""
# stride = 2
branch1 = conv_bn_layer(
name=name + '_branch1',
input=input,
filter_size=1,
num_filters=num_filters2,
stride=stride,
padding=0,
active_type=LinearActivation())
last_name = conv_bn_layer(
name=name + '_branch2a',
input=input,
filter_size=1,
num_filters=num_filters1,
stride=stride,
padding=0)
last_name = conv_bn_layer(
name=name + '_branch2b',
input=last_name,
filter_size=3,
num_filters=num_filters1,
stride=1,
padding=1)
last_name = conv_bn_layer(
name=name + '_branch2c',
input=last_name,
filter_size=1,
num_filters=num_filters2,
stride=1,
padding=0,
active_type=LinearActivation())
return addto_layer(
name=name + "_addto", input=[branch1, last_name], act=ReluActivation())
img = data_layer(name='image', size=height * width * 3)
def deep_res_net(res2_num=3, res3_num=4, res4_num=6, res5_num=3):
"""
A wrapper for 50,101,152 layers of ResNet.
res2_num: number of blocks stacked in conv2_x
res3_num: number of blocks stacked in conv3_x
res4_num: number of blocks stacked in conv4_x
res5_num: number of blocks stacked in conv5_x
"""
# For ImageNet
# conv1: 112x112
tmp = conv_bn_layer(
"conv1",
input=img,
filter_size=7,
channels=3,
num_filters=64,
stride=2,
padding=3)
tmp = img_pool_layer(name="pool1", input=tmp, pool_size=3, stride=2)
# conv2_x: 56x56
tmp = mid_projection(
name="res2_1", input=tmp, num_filters1=64, num_filters2=256, stride=1)
for i in xrange(2, res2_num + 1, 1):
tmp = bottleneck_block(
name="res2_" + str(i), input=tmp, num_filters1=64, num_filters2=256)
# conv3_x: 28x28
tmp = mid_projection(
name="res3_1", input=tmp, num_filters1=128, num_filters2=512)
for i in xrange(2, res3_num + 1, 1):
tmp = bottleneck_block(
name="res3_" + str(i),
input=tmp,
num_filters1=128,
num_filters2=512)
# conv4_x: 14x14
tmp = mid_projection(
name="res4_1", input=tmp, num_filters1=256, num_filters2=1024)
for i in xrange(2, res4_num + 1, 1):
tmp = bottleneck_block(
name="res4_" + str(i),
input=tmp,
num_filters1=256,
num_filters2=1024)
# conv5_x: 7x7
tmp = mid_projection(
name="res5_1", input=tmp, num_filters1=512, num_filters2=2048)
for i in xrange(2, res5_num + 1, 1):
tmp = bottleneck_block(
name="res5_" + str(i),
input=tmp,
num_filters1=512,
num_filters2=2048)
tmp = img_pool_layer(
name='avgpool',
input=tmp,
pool_size=7,
stride=1,
pool_type=AvgPooling())
return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation())
if layer_num == 50:
resnet = deep_res_net(3, 4, 6, 3)
elif layer_num == 101:
resnet = deep_res_net(3, 4, 23, 3)
elif layer_num == 152:
resnet = deep_res_net(3, 8, 36, 3)
else:
print("Wrong layer number.")
lbl = data_layer(name="label", size=num_class)
loss = cross_entropy(name='loss', input=resnet, label=lbl)
inputs(img, lbl)
outputs(loss)
......@@ -5,22 +5,23 @@ function train() {
export OMP_DYNAMIC="FALSE"
export KMP_AFFINITY="granularity=fine,compact,0,0"
topology=$1
bs=$2
use_mkldnn=$3
if [ $3 == "True" ]; then
layer_num=$2
bs=$3
use_mkldnn=$4
if [ $4 == "True" ]; then
thread=1
log="logs/${topology}-mkldnn-${bs}.log"
elif [ $3 == "False" ]; then
log="logs/${topology}-${layer_num}-mkldnn-${bs}.log"
elif [ $4 == "False" ]; then
thread=`nproc`
# each trainer_count use only 1 core to avoid conflict
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
log="logs/${topology}-${thread}mklml-${bs}.log"
log="logs/${topology}-${layer_num}-${thread}mklml-${bs}.log"
else
echo "Wrong input $3, use True or False."
exit 0
fi
args="batch_size=${bs}"
args="batch_size=${bs},layer_num=${layer_num}"
config="${topology}.py"
paddle train --job=time \
--config=$config \
......@@ -40,12 +41,9 @@ if [ ! -d "logs" ]; then
mkdir logs
fi
#========== mkldnn ==========#
train vgg 64 True
train vgg 128 True
train vgg 256 True
#========== mklml ===========#
train vgg 64 False
train vgg 128 False
train vgg 256 False
for use_mkldnn in True False; do
for batchsize in 64 128 256; do
train vgg 19 $batchsize $use_mkldnn
train resnet 50 $batchsize $use_mkldnn
done
done
......@@ -13,7 +13,7 @@ define_py_data_sources2(
settings(
batch_size=batch_size,
learning_rate=0.01 / batch_size,
learning_rate=0.001 / batch_size,
learning_method=MomentumOptimizer(0.9),
regularization=L2Regularization(0.0005 * batch_size))
......
......@@ -55,6 +55,6 @@ After float16 class is available, some of the future items are below:
- Update pybind/tensor_py.h to bind c++ float16 with numpy float16.
- Modify `IndicateDataType()` method in `framework/operator.h` to make it compatible with float16.
- Modify `GetKernelType()` method in `framework/operator.h` to make it compatible with float16.
- Create a type-casting operator that can convert the data type in tensor between float16 and other types.
......@@ -117,7 +117,7 @@ int64_t DDim::operator[](int idx) const {
return boost::apply_visitor(DynamicConstIndexer(idx), var);
}
int64_t DDim::size() const { return arity(*this); }
int DDim::size() const { return arity(*this); }
bool DDim::operator==(DDim d) const {
if (var.which() != d.getVar().which()) {
......
......@@ -71,7 +71,7 @@ struct DDim {
DDim operator*(DDim d) const;
int64_t size() const;
int size() const;
};
/**
......
......@@ -31,6 +31,7 @@ void LoDRankTable::Reset(const LoD& lod, size_t level) {
TableItem item;
item.index = i;
item.length = vec[i + 1] - vec[i];
VLOG(10) << "Add item to rank table " << item.index << " " << item.length;
items_.emplace_back(item);
}
// NOTE(yuyang18):
......
......@@ -27,6 +27,20 @@
namespace paddle {
namespace framework {
std::ostream& operator<<(std::ostream& os, const LoD& lod) {
os << "{";
for (auto& v : lod) {
os << "{";
for (auto& i : v) {
os << i << ",";
}
os << "}";
}
os << "}";
return os;
}
LoD SliceLevels(const LoD& in, size_t level_begin, size_t level_end) {
LoD new_lod;
new_lod.reserve(level_end - level_begin);
......@@ -136,37 +150,35 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
ShareDataWith(Slice(begin, end));
}
void GetFineGrainedLoDLength(const LoD& lod, size_t start_idx, size_t end_idx,
std::vector<std::vector<size_t>>* lod_length,
size_t* start_offset) {
lod_length->clear();
PADDLE_ENFORCE(start_idx < lod.size() - 1,
"start_idx should be >= 0 and < lod.size() - 1.");
PADDLE_ENFORCE(end_idx < lod.size(),
"end_idx should be >= 0 and < lod.size().");
PADDLE_ENFORCE_LE(start_idx, end_idx,
"start_idx should be less than end_idx.");
for (size_t level_idx = 0; level_idx < lod.size(); ++level_idx) {
using LoDAndOffset = std::pair<LoD, std::pair<size_t, size_t>>;
LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD& lod, size_t start_idx,
size_t end_idx, size_t start_level) {
LoD sub_lod;
for (size_t level_idx = start_level; level_idx < lod.size(); ++level_idx) {
PADDLE_ENFORCE_LE(start_idx, end_idx);
PADDLE_ENFORCE_LT(end_idx, lod[level_idx].size());
std::vector<size_t> level_lens;
for (size_t i = start_idx; i < end_idx; ++i) {
level_lens.push_back(lod[level_idx][i + 1] - lod[level_idx][i]);
}
lod_length->emplace_back(level_lens);
sub_lod.emplace_back(level_lens);
start_idx = lod[level_idx][start_idx];
end_idx = lod[level_idx][end_idx];
}
*start_offset = start_idx;
return LoDAndOffset{sub_lod, {start_idx, end_idx}};
}
void AppendLoD(LoD* lod, const std::vector<std::vector<size_t>>& lod_length) {
PADDLE_ENFORCE_EQ(
lod->size(), lod_length.size(),
void AppendLoD(LoD* lod, const LoD& lod_length) {
PADDLE_ENFORCE(
lod->empty() || lod->size() == lod_length.size(),
"The lod_length should has the same size with the appended lod.");
if (lod->empty()) {
*lod = LoD(lod_length.size(), std::vector<size_t>({0}));
}
for (size_t i = 0; i < lod->size(); ++i) {
auto& level = (*lod)[i];
if (level.empty()) {
level.push_back(0);
}
for (size_t len : lod_length[i]) {
level.push_back(level.back() + len);
}
......
......@@ -56,6 +56,8 @@ using Vector = thrust::host_vector<
*/
using LoD = std::vector<Vector<size_t>>;
std::ostream& operator<<(std::ostream& os, const LoD& lod);
/*
* Slice levels from a LoD.
* NOTE the lowest level should always be the absolute offsets of the underlying
......@@ -181,11 +183,10 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
return tensor;
}
void GetFineGrainedLoDLength(const LoD& lod, size_t start_idx, size_t end_idx,
std::vector<std::vector<size_t>>* lod_length,
size_t* start_offset);
std::pair<LoD, std::pair<size_t, size_t>> GetSubLoDAndAbsoluteOffset(
const LoD& lod, size_t start_idx, size_t end_idx, size_t start_level);
void AppendLoD(LoD* lod, const std::vector<std::vector<size_t>>& lod_length);
void AppendLoD(LoD* lod, const LoD& lod_length);
} // namespace framework
} // namespace paddle
......@@ -146,43 +146,44 @@ TEST(LodExpand, test) {
TEST(LoD, GetFineGrainedLoDLength) {
LoD lod;
lod.push_back(std::vector<size_t>{0, 2, 4, 5});
lod.push_back(std::vector<size_t>{0, 1, 6, 8, 10, 11});
lod.push_back(std::vector<size_t>({0, 2, 4, 5}));
lod.push_back(std::vector<size_t>({0, 1, 6, 8, 10, 11}));
lod.push_back(
std::vector<size_t>{0, 2, 5, 7, 10, 12, 15, 17, 20, 24, 26, 29});
std::vector<size_t>({0, 2, 5, 7, 10, 12, 15, 17, 20, 24, 26, 29}));
std::vector<std::vector<size_t>> lod_length;
size_t start_offset;
paddle::framework::GetFineGrainedLoDLength(lod, 1, 2, &lod_length,
&start_offset);
auto lod_and_offset =
paddle::framework::GetSubLoDAndAbsoluteOffset(lod, 1, 2, 0);
LoD lod_length = lod_and_offset.first;
size_t start_offset = lod_and_offset.second.first;
size_t end_offset = lod_and_offset.second.second;
std::vector<std::vector<size_t>> expected;
LoD expected;
expected.push_back(std::vector<size_t>{2});
expected.push_back(std::vector<size_t>{2, 2});
expected.push_back(std::vector<size_t>{2, 3, 4, 2});
EXPECT_EQ(lod_length, expected);
EXPECT_EQ(start_offset, 15UL);
EXPECT_EQ(end_offset, 26UL);
}
TEST(LoD, AppendLoD) {
std::vector<std::vector<size_t>> lod_lens;
lod_lens.push_back(std::vector<size_t>{2});
lod_lens.push_back(std::vector<size_t>{2, 2});
lod_lens.push_back(std::vector<size_t>{2, 3, 4, 2});
LoD lod_lens;
lod_lens.push_back(std::vector<size_t>({2}));
lod_lens.push_back(std::vector<size_t>({2, 2}));
lod_lens.push_back(std::vector<size_t>({2, 3, 4, 2}));
LoD origin;
origin.push_back(std::vector<size_t>{0, 2});
origin.push_back(std::vector<size_t>{0, 1, 6});
origin.push_back(std::vector<size_t>{0, 2, 5, 7, 10, 12, 15});
origin.push_back(std::vector<size_t>({0, 2}));
origin.push_back(std::vector<size_t>({0, 1, 6}));
origin.push_back(std::vector<size_t>({0, 2, 5, 7, 10, 12, 15}));
paddle::framework::AppendLoD(&origin, lod_lens);
LoD expected;
expected.push_back(std::vector<size_t>{0, 2, 4});
expected.push_back(std::vector<size_t>{0, 1, 6, 8, 10});
expected.push_back(std::vector<size_t>({0, 2, 4}));
expected.push_back(std::vector<size_t>({0, 1, 6, 8, 10}));
expected.push_back(
std::vector<size_t>{0, 2, 5, 7, 10, 12, 15, 17, 20, 24, 26});
std::vector<size_t>({0, 2, 5, 7, 10, 12, 15, 17, 20, 24, 26}));
EXPECT_EQ(origin, expected);
}
......
......@@ -92,8 +92,7 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
void operator()(const char* op_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
PlaceType());
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType());
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
......
......@@ -254,8 +254,7 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
return res;
}
std::ostream& operator<<(std::ostream& os,
const OperatorWithKernel::OpKernelKey& kernel_key) {
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key) {
os << "place[" << kernel_key.place_ << "]:data_type[" << kernel_key.data_type_
<< "]";
return os;
......@@ -432,7 +431,7 @@ void OperatorWithKernel::Run(const Scope& scope,
// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
auto kernel_key = GetKernelType(ctx);
auto kernel_iter = kernels.find(kernel_key);
if (kernel_iter == kernels.end()) {
......@@ -440,6 +439,41 @@ void OperatorWithKernel::Run(const Scope& scope,
}
kernel_iter->second->Compute(ctx);
// throws errors if have.
dev_ctx.Finish();
}
OpKernelType OperatorWithKernel::GetKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.device_context());
}
DataType OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
for (auto& input : this->inputs_) {
for (auto& ipt_name : input.second) {
auto* var = scope.FindVar(ipt_name);
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same.", Type());
data_type = tmp;
}
}
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type);
}
} // namespace framework
......
......@@ -345,38 +345,38 @@ class OpKernel : public OpKernelBase {
using ELEMENT_TYPE = T;
};
class OperatorWithKernel : public OperatorBase {
public:
struct OpKernelKey {
struct OpKernelType {
struct Hash {
std::hash<int> hash_;
size_t operator()(const OpKernelType& key) const {
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_);
int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
(place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
return hash_(pre_hash);
}
};
platform::Place place_;
DataType data_type_;
OpKernelKey(DataType data_type, platform::Place place)
OpKernelType(DataType data_type, platform::Place place)
: place_(place), data_type_(data_type) {}
OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
OpKernelType(DataType data_type, const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
bool operator==(const OpKernelKey& o) const {
bool operator==(const OpKernelType& o) const {
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_;
}
};
struct OpKernelHash {
std::hash<int> hash_;
size_t operator()(const OpKernelKey& key) const {
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_);
int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
(place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
return hash_(pre_hash);
}
};
};
class OperatorWithKernel : public OperatorBase {
public:
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
OpKernelHash>;
std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
OpKernelType::Hash>;
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
......@@ -404,40 +404,15 @@ class OperatorWithKernel : public OperatorBase {
}
protected:
virtual OpKernelType GetKernelType(const ExecutionContext& ctx) const;
private:
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
for (auto& input : this->inputs_) {
for (auto& ipt_name : input.second) {
auto* var = scope.FindVar(ipt_name);
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same.",
Type());
data_type = tmp;
}
}
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type);
}
DataType IndicateDataType(const ExecutionContext& ctx) const;
};
std::ostream& operator<<(std::ostream& os,
const OperatorWithKernel::OpKernelKey& kernel_key);
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);
extern bool OpSupportGPU(const std::string& op_type);
......
......@@ -114,8 +114,8 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
DataType IndicateDataType(const ExecutionContext& ctx) const override {
return DataType::FP32;
OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
return OpKernelType(DataType::FP32, ctx.device_context());
}
};
......
......@@ -52,7 +52,7 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
};
static inline size_t SizeOfType(std::type_index type) {
SizeOfTypeFunctor<int, float, double, int16_t, int64_t> functor;
SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool> functor;
size_t size = functor(type);
PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
return size;
......
......@@ -45,7 +45,8 @@ void VarDescBind::SetLoDLevel(int32_t lod_level) {
desc_.mutable_tensor_array()->set_lod_level(lod_level);
break;
default:
PADDLE_THROW("Tensor type=%d does not support LoDLevel", desc_.type());
PADDLE_THROW("Tensor type=%d does not support LoDLevel",
desc_.tensor_array().lod_level());
}
}
......@@ -56,7 +57,8 @@ int32_t VarDescBind::GetLodLevel() const {
case VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().lod_level();
default:
PADDLE_THROW("Tensor type=%d does not support LoDLevel", desc_.type());
PADDLE_THROW("Tensor type=%d does not support LoDLevel",
desc_.tensor_array().lod_level());
}
}
......
......@@ -60,18 +60,16 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() {
}
CHECK(wgtVal_) << "should have been initialized";
bool hasNoSpatial_ = ih_ == 1 && iw_ == 1;
auto targetDim = wgtVal_->getDims();
auto srcFmt = hasNoSpatial_ ? format::io : format::ihwo;
auto srcFmt = targetDim.size() == 2 ? format::io : format::ihwo;
wgtVal_->reorderDataFrom(wgtVal_, srcFmt, targetDim);
hasInitedWgt_ = true;
}
void MKLDNNFcLayer::convertWeightsToPaddle() {
CHECK(wgtVal_) << "should have been initialized";
bool hasNoSpatial_ = ih_ == 1 && iw_ == 1;
auto targetDim = wgtVal_->getDims();
auto dstFmt = hasNoSpatial_ ? format::io : format::ihwo;
auto dstFmt = targetDim.size() == 2 ? format::io : format::ihwo;
wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim);
}
......
......@@ -181,21 +181,17 @@ void MKLDNNLayer::resetInValue(
auto extPD = MKLDNNMatrix::createPrimitiveDesc(
{bs_, ic_, ih_, iw_}, format::nchw, engine_);
const MatrixPtr& inMat = inputLayers_[inputIdx]->getOutputValue();
in = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
CHECK_EQ(inputIsOnlyMKLDNN(), in != nullptr);
if (in == nullptr || in->getFormat() == format::nc) {
in = MKLDNNMatrix::create(extPD, inMat);
}
extInVal_ = isPaddleFormat(in->getFormat()) ? in : nullptr;
if (in->getFormat() == format::nc) {
CHECK(ih_ == 1 && iw_ == 1);
extInVal_ = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
CHECK_EQ(inputIsOnlyMKLDNN(), extInVal_ != nullptr);
if (extInVal_ == nullptr || extInVal_->getFormat() == format::nc) {
extInVal_ = MKLDNNMatrix::create(extPD, inMat);
}
in = extInVal_;
if (nullptr == intPD || in->getPrimitiveDesc() == *intPD) {
return;
}
// need create reorder
in = MKLDNNMatrix::create(*intPD);
extInVal_ = extInVal_ ? extInVal_ : MKLDNNMatrix::create(extPD, inMat);
cvtInVal_ = MKLDNNMatrix::createReorder(extInVal_, in);
CHECK(cvtInVal_) << "should not be emptry";
}
......
......@@ -62,6 +62,11 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
endif()
if ("${TARGET}" STREQUAL "compare_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
endif()
# pool_with_index_op contains several operators
if ("${TARGET}" STREQUAL "pool_with_index_op")
set(pybind_flag 1)
......@@ -165,6 +170,8 @@ set(DEPS_OPS
sequence_conv_op
sequence_pool_op
lod_rank_table_op
lod_tensor_to_array_op
array_to_lod_tensor_op
lstm_op
tensor_array_read_write_op
gru_op)
......@@ -177,6 +184,8 @@ op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op)
op_library(array_to_lod_tensor_op SRCS array_to_lod_tensor_op.cc DEPS lod_rank_table_op)
op_library(tensor_array_read_write_op SRCS tensor_array_read_write_op.cc)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
......@@ -186,8 +195,13 @@ op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
if(WITH_TESTING)
op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS net_op tensor_array gtest)
else()
op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS net_op tensor_array)
endif()
op_library(recurrent_op SRCS recurrent_op.cc DEPS executor)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
......
......@@ -47,10 +47,11 @@ class AccuracyOp : public framework::OperatorWithKernel {
}
protected:
// IndicateDataType
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Out")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
ctx.device_context());
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <numeric>
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h"
namespace paddle {
namespace operators {
using LoD = framework::LoD;
class ArrayToLoDTensorOp : public framework::OperatorBase {
public:
ArrayToLoDTensorOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
auto &rank_table =
scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();
auto *out =
scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
// Check dims, place and data type of input's elements and infer output's
// dim
PADDLE_ENFORCE(!x.empty(), "There's no element in the input array.");
int rank = x[0].dims().size();
platform::Place place = x[0].place();
std::type_index data_type = x[0].type();
framework::DDim ins_dims = framework::slice_ddim(x[0].dims(), 1, rank);
int64_t batch_size = x[0].dims()[0];
for (size_t i = 1; i < x.size(); ++i) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x[i].dims(), 1, rank), ins_dims,
"The dimension of the %zu'th element in LoDTensorArray "
"differs from previous ones.",
i);
PADDLE_ENFORCE(platform::places_are_same_class(x[i].place(), place),
"The place class of the %zu'th element in LoDTensorArray "
"differs from previous ones.",
i);
PADDLE_ENFORCE(x[i].type() == data_type,
"The date type of the %zu'th element in LoDTensorArray "
"differs from previous ones.",
i);
batch_size += x[i].dims()[0];
}
auto ins_dim_vec = framework::vectorize(ins_dims);
ins_dim_vec.insert(ins_dim_vec.begin(), batch_size);
framework::DDim out_dims = framework::make_ddim(ins_dim_vec);
out->Resize(out_dims);
out->mutable_data(place, data_type);
auto &table_items = rank_table.items();
std::vector<size_t> table_item_idx(table_items.size());
// table_item_idx = range(table_items_idx.size())
std::iota(table_item_idx.begin(), table_item_idx.end(), 0);
std::sort(table_item_idx.begin(), table_item_idx.end(),
[&](size_t a, size_t b) {
return table_items[a].index < table_items[b].index;
});
// Build LoDTensor `out`
framework::LoD *out_lod = out->mutable_lod();
out_lod->clear();
size_t out_offset = 0;
auto prefix_lod = rank_table.coarse_lod();
prefix_lod.emplace_back();
auto &cur_level_lod = prefix_lod.back();
cur_level_lod.push_back(0);
for (size_t idx : table_item_idx) {
cur_level_lod.push_back(cur_level_lod.back() + table_items[idx].length);
for (size_t x_idx = 0; x_idx < table_items[idx].length; ++x_idx) {
auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset(
x[x_idx].lod(), idx, idx + 1, 0);
auto &lod_length = lod_and_offset.first;
framework::AppendLoD(out_lod, lod_length);
size_t start_offset = lod_and_offset.second.first;
size_t end_offset = lod_and_offset.second.second;
VLOG(10) << "idx=" << idx << " x_idx=" << x_idx << " ["
<< ", " << end_offset << "]";
// Copy data
PADDLE_ENFORCE_GE(end_offset, start_offset);
size_t len = end_offset - start_offset;
if (len == 0) {
continue;
}
out->Slice(out_offset, out_offset + len)
.CopyFrom(x[x_idx].Slice(start_offset, end_offset), place, dev_ctx);
out_offset += len;
}
}
out_lod->insert(out_lod->begin(), prefix_lod.begin(), prefix_lod.end());
}
};
class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ArrayToLoDTensorOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(std::vector<LodTensor>) A vector of tensors that is going to "
"be casted to a big LoDTensor.");
AddInput("RankTable",
"(LoDRankTable) RankTable provides the coarse lod infomation to "
"build the output LoDTensor. See "
"'paddle/framework/lod_rank_table.h' for more details.");
AddOutput("Out", "(LoDTensor) The LoDTensor formed by input tensor array.");
AddComment(
R"DOC(This Op build a big LoDTensor from a std::vector<LoDTensor>
and a LoDRankTable. It is supposed to be used in getting dynamic RNN's
outputs back to a normal LoDTensor. The std::vector<LoDTensor>
would be the output of RNN Op and the LoDRankTable would be build
with RNN's input.)DOC");
}
};
class ArrayToLoDTensorInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"),
"ArrayToLoDTensorOp must has input X.");
PADDLE_ENFORCE(context->HasInput("RankTable"),
"ArrayToLoDTensorOp must has input RankTable.");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(array_to_lod_tensor, ops::ArrayToLoDTensorOp,
ops::ArrayToLoDTensorOpProtoMaker,
ops::ArrayToLoDTensorInferShape);
......@@ -39,10 +39,11 @@ class AucOp : public framework::OperatorWithKernel {
}
protected:
// IndicateDataType
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Out")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
ctx.device_context());
}
};
......
......@@ -303,7 +303,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
framework::DataType IndicateDataType(
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
......@@ -318,7 +319,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::ToDataType(t->type());
return framework::OpKernelType(framework::ToDataType(t->type()),
ctx.device_context());
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/compare_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CompareOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
OpComment comment;
AddInput("X",
string::Sprintf("(LoDTensor) the left hand operand of %s operator",
comment.type));
AddInput("Y", string::Sprintf(
"(LoDTensor) the right hand operand of %s operator",
comment.type));
AddOutput("Out", string::Sprintf(
"(LoDTensor) n-dim bool tensor. Each element is %s",
comment.equation));
AddComment(string::Sprintf(R"DOC(%s Operator
It operates element-wise on X and Y, and returns the Out. Each of them is a
N-dim tensor. X and Y could be any type. The each element of the Out tensor is
calculated by %s
)DOC",
comment.type, comment.equation));
}
};
template <typename OpComment>
class CompareOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
OpComment comment;
PADDLE_ENFORCE(context->HasInput("X"), "%s operator must has input X",
comment.type);
PADDLE_ENFORCE(context->HasInput("Y"), "%s operator must has input Y",
comment.type);
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
"The number of elements in X and Y should be same");
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
} // namespace operators
} // namespace paddle
#define REGISTER_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OP_WITH_KERNEL( \
op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker);
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/compare_op.h"
REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <type_traits>
#include "paddle/framework/op_registry.h"
#include "paddle/platform/transform.h"
namespace paddle {
namespace operators {
template <typename T>
struct LessThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; }
};
template <typename T>
struct EqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
if (std::is_floating_point<T>::value) {
// This branch will be optimized while compiling if T is integer. It is
// safe to cast a and b to double.
return fabs(static_cast<double>(a - b)) < 1e-8;
} else {
return (a == b);
}
}
};
template <typename Place, typename Functor>
class CompareOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
auto* x = context.Input<framework::Tensor>("X");
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
Functor binary_func;
platform::Transform<Place> trans;
trans(context.device_context(), x->data<T>(), x->data<T>() + x->numel(),
y->data<T>(), out->mutable_data<bool>(context.GetPlace()),
binary_func);
}
};
} // namespace operators
} // namespace paddle
#define REGISTER_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
functor<int>>, \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
functor<int64_t>>, \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
functor<float>>, \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
functor<double>>);
......@@ -120,9 +120,11 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
ctx.device_context());
}
};
} // namespace operators
......
......@@ -51,9 +51,11 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
}
};
......@@ -98,9 +100,11 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
}
};
......
......@@ -49,9 +49,11 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
ctx.device_context());
}
};
......
......@@ -33,11 +33,12 @@ class FillConstantOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
int data_type = ctx.Attr<int>("data_type");
VLOG(10) << " FillConstant data_type = " << data_type;
return static_cast<framework::DataType>(data_type);
return framework::OpKernelType(static_cast<framework::DataType>(data_type),
ctx.device_context());
}
};
......
......@@ -40,9 +40,11 @@ class GatherOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
}
};
......@@ -55,9 +57,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
}
};
......
......@@ -57,9 +57,11 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
ctx.device_context());
}
};
......
......@@ -183,9 +183,11 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of linear_chain_crf
// is determined by its input "Emission".
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
ctx.device_context());
}
};
......@@ -240,10 +242,13 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of output of the linear_chain_crf_grad
// operator is determined by its input: gradients of LogLikelihood.
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))->type());
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))
->type()),
ctx.device_context());
}
};
......
......@@ -28,6 +28,7 @@ class LoDRankTableOp : public framework::OperatorBase {
auto x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto *out =
scope.FindVar(Output("Out"))->GetMutable<framework::LoDRankTable>();
VLOG(10) << "Level = " << static_cast<size_t>(Attr<int>("level"));
out->Reset(x.lod(), static_cast<size_t>(Attr<int>("level")));
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
struct CopyRange {
size_t begin;
size_t end;
};
class LoDTensorToArrayOp : public framework::OperatorBase {
public:
LoDTensorToArrayOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &rank_table =
scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();
auto &out =
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensorArray>();
auto &items = rank_table.items();
auto max_seq_len = items[0].length;
auto rank_level = rank_table.level();
out.resize(max_seq_len);
std::vector<std::vector<CopyRange>> copy_ranges(max_seq_len);
// set out[i] lod
for (size_t t = 0; t < max_seq_len; t++) {
auto &lod = *out[t].mutable_lod();
lod.clear();
for (auto &item : items) {
if (t >= item.length) {
break;
}
size_t start_idx = x.lod()[rank_level][item.index] + t;
auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset(
x.lod(), start_idx, start_idx + 1, rank_level + 1);
auto &lod_length = lod_and_offset.first;
framework::AppendLoD(&lod, lod_length);
size_t start_offset = lod_and_offset.second.first;
size_t end_offset = lod_and_offset.second.second;
copy_ranges[t].emplace_back(CopyRange{start_offset, end_offset});
}
}
for (size_t i = 0; i < max_seq_len; ++i) {
auto &ranges = copy_ranges[i];
size_t height = std::accumulate(
ranges.begin(), ranges.end(), 0UL,
[](size_t a, const CopyRange &b) { return a + b.end - b.begin; });
auto x_dim = x.dims();
x_dim[0] = static_cast<int64_t>(height);
out[i].Resize(x_dim);
out[i].mutable_data(x.place(), x.type());
size_t offset = 0;
for (auto &each_range : ranges) {
size_t len = each_range.end - each_range.begin;
if (len == 0) {
continue;
}
// out[i][offset: offset+len] = x[each_range.begin: each_range.end]
out[i]
.Slice(static_cast<int>(offset), static_cast<int>(offset + len))
.CopyFrom(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)),
x.place(), dev_ctx);
offset += len;
}
}
}
};
class LoDTensorToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDTensorToArrayOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "");
AddInput("RankTable", "");
AddOutput("Out", "");
AddComment("");
}
};
class LoDTensorToArrayInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"),
"Input(X) of LoDTensorToArrayOp should not be null.");
PADDLE_ENFORCE(
context->HasInput("RankTable"),
"Input(RankTable) of LoDTensorToArrayOp should not be null.");
PADDLE_ENFORCE(context->HasOutput("Out"),
"Output(Out) of LoDTensorToArrayOp should not be null.");
auto x_dim = context->GetInputDim("X");
// The first dim of each LoDTensor in Output can only be set at run-time.;
// We still have to Resize each LoDTensor in Output.
context->SetOutputDim("Out", x_dim);
}
};
class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override {
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(lod_tensor_to_array, ops::LoDTensorToArrayOp,
ops::LoDTensorToArrayOpProtoMaker,
ops::LoDTensorToArrayInferShape,
ops::LoDTensorToArrayInferVarType);
......@@ -41,9 +41,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
ctx.device_context());
}
};
......@@ -97,9 +99,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
ctx.device_context());
}
};
......
......@@ -84,10 +84,11 @@ class LSTMOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<framework::LoDTensor>("Input")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
ctx.device_context());
}
};
......@@ -245,10 +246,11 @@ class LSTMGradOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<framework::LoDTensor>("Input")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
ctx.device_context());
}
};
......
......@@ -244,11 +244,6 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
}
cudaStreamSynchronize(stream);
// TODO(qingqing): Add cuda error check for each kernel.
cudaError_t err = cudaGetLastError();
PADDLE_ENFORCE(err, cudaGetErrorString(err));
}
} // namespace detail
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
#include "paddle/framework/data_type.h"
namespace paddle {
namespace operators {
......@@ -233,6 +234,52 @@ void gemv<platform::CPUPlace, double>(const platform::DeviceContext& context,
template struct SetConstant<platform::CPUPlace, float>;
struct TensorSetConstant {
TensorSetConstant(framework::Tensor* tensor, float value)
: tensor_(tensor), value_(value) {}
template <typename T>
void operator()() const {
auto cpu = platform::CPUPlace();
auto* begin = tensor_->mutable_data<T>(cpu);
std::fill(begin, begin + tensor_->numel(), static_cast<T>(value_));
}
framework::Tensor* tensor_;
float value_;
};
template <>
void set_constant_with_place<platform::CPUPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
TensorSetConstant(tensor, value));
}
struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
TensorSetConstantWithPlace(const platform::DeviceContext& context,
framework::Tensor* tensor, float value)
: context_(context), tensor_(tensor), value_(value) {}
template <typename Place>
void operator()(Place place) const {
set_constant_with_place<Place>(context_, tensor_, value_);
}
const platform::DeviceContext& context_;
framework::Tensor* tensor_;
float value_;
};
void set_constant(const platform::DeviceContext& context,
framework::Tensor* tensor, float value) {
TensorSetConstantWithPlace func(context, tensor, value);
#ifdef PADDLE_WITH_CUDA
tensor->place().apply_visitor(func);
#else
func(platform::CPUPlace());
#endif
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_type.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
......@@ -232,6 +233,30 @@ void gemv<platform::GPUPlace, double>(const platform::DeviceContext& context,
template struct SetConstant<platform::GPUPlace, float>;
struct TensorSetConstant {
TensorSetConstant(const platform::DeviceContext& context,
framework::Tensor* tensor, float value)
: context_(context), tensor_(tensor), value_(value) {}
template <typename T>
void operator()() const {
SetConstant<platform::GPUPlace, T> functor;
functor(context_, tensor_, static_cast<T>(value_));
}
const platform::DeviceContext& context_;
framework::Tensor* tensor_;
float value_;
};
template <>
void set_constant_with_place<platform::GPUPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
TensorSetConstant(context, tensor, value));
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -108,6 +108,13 @@ struct SetConstant {
}
};
template <typename Place>
void set_constant_with_place(const platform::DeviceContext& context,
framework::Tensor* tensor, float value);
void set_constant(const platform::DeviceContext& context,
framework::Tensor* tensor, float value);
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -139,3 +139,15 @@ TEST(math_function, gemv) {
GemvTest<float>(12, 7, true);
GemvTest<double>(7, 9, true);
}
TEST(math_funciton, set_constant) {
paddle::framework::Tensor t;
t.Resize({10, 10});
t.mutable_data<int>(paddle::platform::CPUPlace());
auto* ctx = new paddle::platform::CPUDeviceContext();
paddle::operators::math::set_constant(*ctx, &t, 10);
for (int64_t i = 0; i < t.numel(); ++i) {
PADDLE_ENFORCE_EQ(10, t.data<int>()[i]);
}
delete ctx;
}
......@@ -51,9 +51,11 @@ class MultiplexOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type());
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
ctx.device_context());
}
};
......@@ -107,9 +109,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type());
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
ctx.device_context());
}
};
......
......@@ -85,9 +85,11 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Score")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Score")->type()),
ctx.device_context());
}
};
......
......@@ -80,9 +80,11 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type()),
ctx.device_context());
}
};
......
......@@ -49,9 +49,11 @@ class ScatterOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Ref")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
ctx.device_context());
}
};
......@@ -66,9 +68,11 @@ class ScatterGradOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Ref")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
ctx.device_context());
}
};
......
......@@ -107,9 +107,11 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
}
};
......
......@@ -121,9 +121,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Logits")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
ctx.device_context());
}
};
......@@ -160,10 +162,12 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<Tensor>(framework::GradVarName("Loss"))->type());
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<Tensor>(framework::GradVarName("Loss"))->type()),
ctx.device_context());
}
};
......
......@@ -47,20 +47,24 @@ class SumOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
auto x_vars = ctx.MultiInputVar("X");
if (x_vars[0]->IsType<framework::LoDTensor>()) {
return framework::ToDataType(
x_vars[0]->Get<framework::LoDTensor>().type());
return framework::OpKernelType(
framework::ToDataType(x_vars[0]->Get<framework::LoDTensor>().type()),
ctx.device_context());
} else if (x_vars[0]->IsType<framework::SelectedRows>()) {
return framework::ToDataType(
x_vars[0]->Get<framework::SelectedRows>().value().type());
return framework::OpKernelType(
framework::ToDataType(
x_vars[0]->Get<framework::SelectedRows>().value().type()),
ctx.device_context());
} else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
auto& array = x_vars[0]->Get<framework::LoDTensorArray>();
for (auto& each : array) {
if (each.numel() != 0) {
return framework::ToDataType(each.type());
return framework::OpKernelType(framework::ToDataType(each.type()),
ctx.device_context());
}
}
}
......
......@@ -63,9 +63,11 @@ class UniformRandomOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
ctx.device_context());
}
};
......
......@@ -124,6 +124,11 @@ void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
}
void CUDADeviceContext::Finish() const {
Wait();
PADDLE_ENFORCE(cudaGetLastError());
}
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}
......
......@@ -46,6 +46,8 @@ class DeviceContext {
DeviceType* GetEigenDevice() const;
virtual void Wait() const {}
virtual void Finish() const {}
};
class CPUDeviceContext : public DeviceContext {
......@@ -77,6 +79,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Wait for all operations completion in the stream. */
void Wait() const override;
/*! \brief Check potential errors for the cuda kernel calls. */
void Finish() const override;
/*! \brief Return place in the device context. */
Place GetPlace() const override;
......
......@@ -113,11 +113,13 @@ PYBIND11_PLUGIN(core) {
.def("set", PyCPUTensorSetFromArray<int>)
.def("set", PyCPUTensorSetFromArray<double>)
.def("set", PyCPUTensorSetFromArray<int64_t>)
.def("set", PyCPUTensorSetFromArray<bool>)
#ifdef PADDLE_WITH_CUDA
.def("set", PyCUDATensorSetFromArray<float>)
.def("set", PyCUDATensorSetFromArray<int>)
.def("set", PyCUDATensorSetFromArray<double>)
.def("set", PyCUDATensorSetFromArray<int64_t>)
.def("set", PyCUDATensorSetFromArray<bool>)
#endif
.def("shape", [](Tensor &self) { return vectorize(self.dims()); })
.def("set_float_element", TensorSetElement<float>)
......
......@@ -85,7 +85,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
} // namespace details
inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
auto buffer_info =
details::CastToPyBufferImpl<true, 0, float, int, double, int64_t>()(
details::CastToPyBufferImpl<true, 0, float, int, double, int64_t, bool>()(
tensor);
return buffer_info;
}
......
......@@ -6548,26 +6548,27 @@ def switch_order_layer(input,
@layer_support()
def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None):
"""
This layer crops images by offset and shape. User can set crop shape by
args 'shape' explicitly or by reference input layer.
This layer crops images according to the offset and shape. Users can set
the crop shape through the argument 'shape' explicitly or by specifying a
reference input layer.
The example usage is:
.. code-block:: python
crop = crop_layer(input=[image_input, reference_input], axis=2, offset=[2, 3])
:param input: The input of this layer. If two inputs are given, the second input
will be regarded as reference input.
:param input: The input of this layer. If two inputs are given, the second one
will be regarded as the reference.
:type input: LayerOutput | Sequence
:param offset: The crop offset.
:type offset: Sequence
:param axis: start axis to be cropped. To image input layer:
:param axis: The start axis to be cropped. For image input layer:
- 0: batch size
- 1: channels
- 2: height
- 3: width
:type partial_sum: int
:param shape: The shape to be cropped. Default is None.
:type axis: int
:param shape: The shape to be cropped to. Default is None.
:type shape: Sequence | None
:param name: The name of this layer. It is optional.
:type name: basestring
......@@ -6702,9 +6703,9 @@ def seq_slice_layer(input, starts, ends, name=None):
:type name: basestring
:param input: The input of this layer, which should be a sequence.
:type input: LayerOutput
:param starts: start indices to slice the input sequence.
:param starts: The start indices to slice the input sequence.
:type starts: LayerOutput | None
:param ends: end indices to slice the input sequence.
:param ends: The end indices to slice the input sequence.
:type ends: LayerOutput | None
:return: LayerOutput object.
:rtype: LayerOutput
......@@ -6744,7 +6745,7 @@ def seq_slice_layer(input, starts, ends, name=None):
@layer_support()
def kmax_seq_score_layer(input, name=None, beam_size=1):
"""
This layer accepts one input which are scores over a sequence or a nested
This layer accepts one input which is scores over a sequence or a nested
sequence, and returns indices of beam_size sequences with highest scores.
.. code-block:: python
......@@ -6754,11 +6755,11 @@ def kmax_seq_score_layer(input, name=None, beam_size=1):
:param name: The name of this layer. It is optional.
:type name: basestring
:param input: The input of this layer. It stores scores over a sequence or a nested
sequence and its size must be 1.
:param input: The input of this layer. It stores scores over a sequence or
a nested sequence and its size must be 1.
:type input: LayerOutput
:param beam_size: sequence indices with top beam_size scores are returned.
:type beam_size: double
:param beam_size: The indices of the sequences with top beam_size scores are returned.
:type beam_size: int
:return: LayerOutput object.
:rtype: LayerOutput
"""
......@@ -6814,38 +6815,42 @@ def img_conv3d_layer(input,
:type name: basestring
:param input: The input of this layer.
:type input: LayerOutput
:param filter_size: The x dimension of a filter kernel. Or input a list.
:param filter_size: The dimensions of the filter kernel along three axises. If the parameter
is set to one integer, the three dimensions will be same.
:type filter_size: int | tuple | list
:param num_filters: Each filter group's number of filter
:param num_filters: The number of filters in each group.
:type num_filters: int
:param act: Activation type. ReluActivation is the default.
:type act: BaseActivation
:param groups: Group size of filters.
:param groups: The number of the filter groups.
:type groups: int
:param stride: The x dimension of the stride. Or input a tuple for two image
dimension.
:param stride: The strides of the convolution along three axises. If the parameter
is set to one integer, the three strides will be same.
:type stride: int | tuple | list
:param padding: The x dimension of the padding. Or input a tuple for two
image dimension
:param padding: The numbers of padding along three axises. If the parameter is set to
one integer, they will be same.
:type padding: int | tuple | list
:param bias_attr: Convolution bias attribute. None means default bias.
False means no bias.
:param bias_attr: The Bias Attribute. If the parameter is set to
False or something not type of ParameterAttribute,
no bias is defined. If the parameter is set to
True, the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param num_channels: number of input channels. If None will be set
automatically from previous output.
:param num_channels: The number of input channels. If the parameter is not set or
set to None, its actual value will be automatically set to
the channels number of the input .
:type num_channels: int
:param param_attr: Convolution param attribute. None means default attribute
:param param_attr: The parameter attribute of the convolution.
:type param_attr: ParameterAttribute
:param shared_biases: Is biases will be shared between filters or not.
:param shared_biases: Whether biases will be shared between filters or not.
:type shared_biases: bool
:param layer_attr: Layer Extra Attribute.
:param layer_attr: Extra layer attributes.
:type layer_attr: ExtraLayerAttribute
:param trans: true if it is a convTransLayer, false if it is a convLayer
:param trans: True if it is a convTransLayer, False if it is a convLayer
:type trans: bool
:param layer_type: specify the layer_type, default is None. If trans=True,
layer_type has to be "exconvt" or "cudnn_convt",
otherwise layer_type has to be either "exconv" or
"cudnn_conv"
:type layer_type: String
:param layer_type: Specify the layer_type. If the parameter is set, it must be "deconv3d"
when trans=True. If not set, it will be automatically set to "deconv3d"
when trans=True and "conv3d" when trans=False.
:type layer_type: basestring
:return: LayerOutput object.
:rtype: LayerOutput
"""
......@@ -6927,7 +6932,7 @@ def img_conv3d_layer(input,
def scale_shift_layer(input, name=None, param_attr=None, bias_attr=None):
"""
A layer applies a linear transformation to each element in each row of
the input matrix. For each element, the layer first re-scale it and then
the input matrix. For each element, the layer first re-scales it and then
adds a bias to it.
This layer is very like the SlopeInterceptLayer, except the scale and
......@@ -7001,12 +7006,12 @@ def sub_seq_layer(input, offsets, sizes, act=None, bias_attr=None, name=None):
:type name: basestring
:param input: The input of this layer, which should be sequence.
:type input: LayerOutput
:param offsets: offset indices to slice the input sequence, which should be
sequence type.
:param offsets: The offset indices to slice the input sequence, which should
be sequence type.
:type offsets: LayerOutput
:param sizes: sizes of the sub-sequences, which should be sequence type.
:param sizes: The sizes of the sub-sequences, which should be sequence type.
:type sizes: LayerOutput
:param act: Layer activation, default is LinearActivation
:param act: Activation type, LinearActivation is the default.
:type act: BaseActivation.
:param bias_attr: The Bias Attribute. If the parameter is set to
False or something not type of ParameterAttribute,
......
......@@ -22,6 +22,7 @@ parse training set and test set into paddle reader creators.
import numpy as np
import os
import paddle.v2.dataset.common
from paddle.v2.parameters import Parameters
__all__ = ['train', 'test']
......@@ -34,7 +35,8 @@ feature_names = [
UCI_TRAIN_DATA = None
UCI_TEST_DATA = None
URL_MODEL = 'https://github.com/PaddlePaddle/book/raw/develop/01.fit_a_line/fit_a_line.tar'
MD5_MODEL = '52fc3da8ef3937822fcdd87ee05c0c9b'
def feature_range(maximums, minimums):
import matplotlib
......@@ -111,6 +113,13 @@ def test():
return reader
def model():
tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'fit_a_line.tar', MD5_MODEL)
with open(tar_file, 'r') as f:
parameters = Parameters.from_tar(f)
return parameters
def fetch():
paddle.v2.dataset.common.download(URL, 'uci_housing', MD5)
......
......@@ -775,6 +775,30 @@ def lod_rank_table(x, level=0, main_program=None):
return table
def lod_tensor_to_array(x, table, main_program=None):
helper = LayerHelper("lod_tensor_to_array", **locals())
array = helper.create_variable(
name=unique_name("lod_tensor_to_array"),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY)
helper.append_op(
type='lod_tensor_to_array',
inputs={'X': x,
'RankTable': table},
outputs={'Out': array})
return array
def array_to_lod_tensor(x, table, main_program=None):
helper = LayerHelper("array_to_lod_tensor", **locals())
tmp = helper.create_tmp_variable(dtype=x.data_type)
helper.append_op(
type="array_to_lod_tensor",
inputs={'X': x,
'RankTable': table},
outputs={'Out': tmp})
return tmp
def fill_constant(shape, dtype, value, main_program=None):
helper = LayerHelper("ones", **locals())
out = helper.create_tmp_variable(dtype=dtype)
......
......@@ -26,4 +26,5 @@ class TestAccuracyOp(OpTest):
if __name__ == '__main__':
exit(0)
unittest.main()
import op_test
import unittest
import numpy
def create_test_class(op_type, typename, callback):
class Cls(op_test.OpTest):
def setUp(self):
a = numpy.random.random(size=(10, 7)).astype(typename)
b = numpy.random.random(size=(10, 7)).astype(typename)
c = callback(a, b)
self.inputs = {'X': a, 'Y': b}
self.outputs = {'Out': c}
self.op_type = op_type
def test_output(self):
self.check_output()
cls_name = "{0}_{1}".format(op_type, typename)
Cls.__name__ = cls_name
globals()[cls_name] = Cls
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
if __name__ == '__main__':
unittest.main()
......@@ -18,7 +18,6 @@ class TestLoDRankTable(unittest.TestCase):
tensor = core.LoDTensor()
tensor.set(numpy.random.random(size=(17, 100)), cpu)
tensor.set_lod([[0, 1, 3], [0, 5, 6, 7], [0, 3, 4, 9, 10, 13, 16, 17]])
exe.run(g_main_program, scope=scope, feed={'x': tensor})
var = scope.find_var(rank_table.name)
table = var.get_lod_rank_table()
......
import unittest
import paddle.v2.framework.core as core
import numpy
import paddle.v2.framework.layers as layers
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
class TestCPULoDTensorArrayOps(unittest.TestCase):
def place(self):
return core.CPUPlace()
def test_lod_tensor_to_array_level_0(self):
tensor = core.LoDTensor()
tensor.set(
numpy.arange(10).reshape(10, 1).astype('int32'), self.place())
tensor.set_lod([[0, 3, 9, 10]])
expect = map(lambda x: numpy.array(x).astype('int32'),
[[3, 0, 9], [4, 1], [5, 2], [6], [7], [8]])
self.main(tensor=tensor, expect_array=expect, expect_lod=[] * 6)
def test_lod_tensor_to_array_level_0_empty_seq(self):
tensor = core.LoDTensor()
tensor.set(
numpy.arange(10).reshape(10, 1).astype('int32'), self.place())
tensor.set_lod([[0, 3, 9, 9, 10]])
expect = map(lambda x: numpy.array(x).astype('int32'),
[[3, 0, 9], [4, 1], [5, 2], [6], [7], [8]])
self.main(tensor=tensor, expect_array=expect, expect_lod=[] * 6)
def test_lod_tensor_to_array_level_1(self):
tensor = core.LoDTensor()
tensor.set(
numpy.arange(20).reshape(20, 1).astype('int32'), self.place())
tensor.set_lod([[0, 2, 5], [0, 3, 9, 11, 17, 20]])
expect = [
numpy.array(
[9, 10, 0, 1, 2], dtype='int32'), numpy.array(
[11, 12, 13, 14, 15, 16, 3, 4, 5, 6, 7, 8], dtype='int32'),
numpy.array(
[17, 18, 19], dtype='int32')
]
lod = [[[0, 2, 5]], [[0, 6, 12]], [[0, 3]]]
self.main(tensor=tensor, expect_array=expect, expect_lod=lod)
def test_lod_tensor_to_array_level_1_empty_seq(self):
tensor = core.LoDTensor()
tensor.set(
numpy.arange(31).reshape(31, 1).astype('int32'), self.place())
tensor.set_lod([[0, 3, 5, 9, 11],
[0, 3, 7, 11, 11, 12, 17, 19, 21, 23, 30, 31]])
expect = [
numpy.array(
item, dtype='int32')
for item in [[
12, 13, 14, 15, 16, 0, 1, 2, 23, 24, 25, 26, 27, 28, 29
], [17, 18, 3, 4, 5, 6, 11, 30], [19, 20, 7, 8, 9, 10], [21, 22]]
]
lod = [[[0, 5, 8, 8, 15]], [[0, 2, 6, 7, 8]], [[0, 2, 6]], [[0, 2]]]
self.main(tensor=tensor, expect_array=expect, expect_lod=lod)
def test_lod_tensor_to_array_level_2(self):
tensor = core.LoDTensor()
tensor.set(
numpy.arange(50).reshape(50, 1).astype('int32'), self.place())
tensor.set_lod([[0, 2, 5, 6], [0, 2, 5, 6, 10, 12, 13],
[0, 3, 7, 11, 17, 21, 22, 23, 27, 31, 39, 45, 46, 50]])
expect = [
numpy.array(
item, dtype='int32')
for item in [[21, 0, 1, 2, 3, 4, 5, 6, 46, 47, 48, 49], range(
22, 39) + range(7, 21), range(39, 46)]
]
lod = [[[0, 1, 3, 4], [0, 1, 4, 8, 12]],
[[0, 4, 7], [0, 1, 5, 9, 17, 21, 27, 31]], [[0, 2], [0, 6, 7]]]
self.main(tensor=tensor, expect_array=expect, expect_lod=lod)
def test_lod_tensor_to_array_level_2_skip_level(self):
tensor = core.LoDTensor()
tensor.set(
numpy.arange(50).reshape(50, 1).astype('int32'), self.place())
tensor.set_lod([[0, 2, 5, 6], [0, 2, 5, 6, 10, 12, 13],
[0, 3, 7, 11, 17, 21, 22, 23, 27, 31, 39, 45, 46, 50]])
self.main(tensor=tensor, expect_array=None, expect_lod=None, level=1)
def main(self, tensor, expect_array, expect_lod, level=0):
place = self.place()
program = Program()
x = layers.data(name='x', shape=[10], main_program=program)
x.persistable = True
table = layers.lod_rank_table(x, level=level, main_program=program)
array = layers.lod_tensor_to_array(x, table, main_program=program)
array.persistable = True
result = layers.array_to_lod_tensor(array, table, main_program=program)
result.persistable = True
exe = Executor(place)
scope = core.Scope()
exe.run(program, feed={'x': tensor}, scope=scope)
var = scope.find_var(array.name)
array = var.get_lod_tensor_array()
if expect_array is not None and expect_lod is not None:
self.check_array_same(array, expect_array, expect_lod)
self.check_tensor_same(scope.find_var(result.name).get_tensor(), tensor)
def check_array_same(self, array, expect_tensor, expect_lod):
self.assertEqual(len(expect_tensor), len(array))
for i, exp in enumerate(zip(expect_tensor, expect_lod)):
exp_tensor, exp_lod = exp
exp_tensor = numpy.expand_dims(exp_tensor, axis=1)
self.assertTrue(numpy.allclose(exp_tensor, numpy.array(array[i])))
self.assertEqual(exp_lod, array[i].lod())
def check_tensor_same(self, actual, expect):
self.assertTrue(
numpy.allclose(numpy.array(actual), numpy.array(expect)))
self.assertEqual(actual.lod(), expect.lod())
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册