Commit a38bbc86 authored by: S sweetsky0901

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into my_unpool_max_2d

......@@ -463,7 +463,7 @@ function(py_test TARGET_NAME)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND env PYTHONPATH=${PADDLE_PYTHON_BUILD_DIR}/lib-python
python2 ${py_test_SRCS}
${PYTHON_EXECUTABLE} ${py_test_SRCS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endfunction()
......@@ -168,17 +168,3 @@ function(create_resources res_file output_file)
COMMAND python ARGS ${PADDLE_SOURCE_DIR}/cmake/make_resource.py ${res_file} ${output_file}
DEPENDS ${res_file} ${PADDLE_SOURCE_DIR}/cmake/make_resource.py)
endfunction()
# Create a python unittest using run_python_tests.sh,
# which takes care of making correct running environment
function(add_python_test TEST_NAME)
foreach(arg ${ARGN})
get_filename_component(py_fn ${arg} NAME_WE)
set(TRG_NAME ${TEST_NAME}_${py_fn})
add_test(NAME ${TRG_NAME}
COMMAND env PYTHONPATH=${PADDLE_PYTHON_PACKAGE_DIR}
python2 ${arg}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endforeach()
endfunction()
......@@ -6,7 +6,10 @@ cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
......@@ -51,10 +54,6 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope frame
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor)
cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place)
cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
proto_desc)
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
......
......@@ -22,7 +22,6 @@
#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/operators/net_op.h"
namespace paddle {
......@@ -218,21 +217,6 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
return false;
});
// process recurrent gradient op as a special operator.
if (forwardOp.Type() == "dynamic_recurrent") {
// NOTE clean up the cycle call somewhere (RNN's stepnet contains itself),
// or this will result in an infinite loop.
const auto& rnnop =
*static_cast<const operators::DynamicRecurrentOp*>(&forwardOp);
auto rnn_grad_op =
static_cast<operators::DynamicRecurrentGradientOp*>(grad_op.get());
const auto& stepnet_op =
*static_cast<const OperatorBase*>(&rnnop.rnn.GetStepUnit());
// create stepnet's gradient op
rnn_grad_op->rnn.SetStepUnit(
BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
}
if (net->ops_.empty()) { // Currently no aux op is added to the network
return grad_op;
}
......@@ -522,7 +506,7 @@ ParamGradInfoMap AppendBackward(
new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
{{"shape", std::vector<int>{1}},
{"value", static_cast<float>(1.0)},
{"data_type", target.GetDataType()}}));
{"dtype", target.GetDataType()}}));
// infer var type of fill_one_op
fill_one_op->InferVarType(root_block);
......
......@@ -120,7 +120,7 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(10) << op->DebugString();
VLOG(3) << op->DebugString();
op->Run(*local_scope, *device);
}
if (create_local_scope) {
......
......@@ -24,6 +24,7 @@
#include <glog/logging.h>
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
......@@ -175,9 +176,9 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1);
for (size_t ins = 0; ins < num_instances; ins++) {
for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) {
tensor.Slice(elem, elem + 1)
.CopyFrom(source.Slice(ins, ins + 1), platform::CPUPlace(),
platform::CPUDeviceContext());
auto slice = tensor.Slice(elem, elem + 1);
CopyFrom(source.Slice(ins, ins + 1), platform::CPUPlace(),
platform::CPUDeviceContext(), &slice);
}
}
return tensor;
......
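The hunk above migrates a call site from the member function Tensor::CopyFrom to the free function framework::CopyFrom, which takes the destination tensor as an explicit out-parameter (defined in the new tensor_util.h later in this diff). A minimal sketch of the new calling convention; the function name and tensor shape here are illustrative only:

#include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"

void copy_example() {
  paddle::framework::Tensor src, dst;
  src.mutable_data<float>(paddle::framework::make_ddim({2, 3}),
                          paddle::platform::CPUPlace());
  // Old member-function style (removed by this change):
  //   dst.CopyFrom(src, platform::CPUPlace(), platform::CPUDeviceContext());
  // New free-function style: the destination is passed last as a pointer.
  paddle::framework::CopyFrom(src, paddle::platform::CPUPlace(),
                              paddle::platform::CPUDeviceContext(), &dst);
}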
......@@ -26,6 +26,8 @@ namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
const std::string kDropOutOpType = "dropout";
const std::string kBatchNormOpType = "batch_norm";
bool HasDependentVar(const OpDesc& op_desc,
const std::set<std::string>& dependent_vars) {
......@@ -106,5 +108,26 @@ void Prune(const ProgramDesc& input, ProgramDesc* output) {
prune_impl(input, output, 0);
}
void inference_optimize_impl(const ProgramDesc& input, ProgramDesc* output,
int block_id) {
*output = input;
auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
for (auto& op_desc : *op_field) {
if (op_desc.type() == kDropOutOpType ||
op_desc.type() == kBatchNormOpType) {
for (auto& attr : *op_desc.mutable_attrs()) {
if (attr.name() == "is_test") {
attr.set_b(true);
break;
}
}
}
}
}
void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output) {
inference_optimize_impl(input, output, 0);
}
} // namespace framework
} // namespace paddle
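The new InferenceOptimize pass copies the whole program and only flips the "is_test" attribute of dropout and batch_norm ops in the given block. A minimal usage sketch; the wrapper MakeInferenceProgram is illustrative and not part of this diff:

#include "paddle/framework/prune.h"

// Return an inference-ready copy of `train_program` in which every dropout
// and batch_norm op in block 0 runs in test mode (is_test == true).
paddle::framework::ProgramDesc MakeInferenceProgram(
    const paddle::framework::ProgramDesc& train_program) {
  paddle::framework::ProgramDesc inference_program;
  paddle::framework::InferenceOptimize(train_program, &inference_program);
  return inference_program;
}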
......@@ -22,5 +22,7 @@ namespace framework {
void Prune(const ProgramDesc& input, ProgramDesc* output);
void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output);
} // namespace framework
} // namespace paddle
......@@ -89,34 +89,6 @@ class Tensor {
/*! The internals of the two tensors share the same memory block. */
inline Tensor& ShareDataWith(const Tensor& src);
/**
* @brief Copy the content of external tensor to a new place.
*
* @param[in] src The external tensor.
* @param[in] dst_place The dst place.
* @param[in] ctx The device context contains device resources.
*
* @note CopyFrom supports CPU <-> GPU, GPU <-> GPU.
*/
// TODO(qijun): https://github.com/PaddlePaddle/Paddle/issues/4647
// Remove `CopyFrom` and `CopyFromVector` from Tensor interface
// and make them global functions
inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
const platform::DeviceContext& ctx);
/**
* @brief Copy the content of an external vector to a tensor.
*
* @param[in] src The external tensor.
* @param[in] ctx The device context contains device resources.
*
* * @note CopyFromVector assumes that the tensor has been resized
* before invoking.
*/
template <typename T>
inline void CopyFromVector(const std::vector<T>& src,
const platform::DeviceContext& ctx);
/**
* @brief Return a sub-tensor of the given tensor.
*
......@@ -141,7 +113,6 @@ class Tensor {
size_t memory_size() const;
private:
inline void check_memory_size() const;
private:
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/tensor_array.h"
#include <glog/logging.h>
#include <algorithm>
#include <limits>
#include "paddle/framework/eigen.h"
namespace paddle {
namespace framework {
namespace detail {
/*
 * Offer an iterator over the length-sorted top level of a lod-tensor. The top
 * level of a lod-tensor stores a batch of sequences, and each top-level
 * sequence may contain several lower-level sequences. Sorting the top-level
 * lod by the number of lower-level sequences in descending order keeps the
 * batch size decreasing while the RNN runs, so the short sequences end at the
 * tail of each batch.
*
* Let's take a simple lod-tensor for example
*
* |(0) |(1) top-level has two instances
* ||| ||||| lower-level
*
* sort by lower-level's length
*
* |(1) |(0)
* ||||| |||
*
 * when the RNN runs, it gets 5 batches (equal to the number of elements in the
 * longest sequence)
*
* |||||
* |||
*
 * the first three batches have two elements each, and the last two batches
 * have just one element each.
*/
struct DynamicBatchUnpacker {
using value_type = float;
DynamicBatchUnpacker(const LoDTensor& source, size_t level,
bool descend = true)
: source(&source), level(level) {
BuildLengthSortedMeta(descend);
}
LoDTensor GetBatch(size_t index);
std::vector<DySeqMeta> meta;
LoDTensor const* source;
size_t level;
protected:
void BuildLengthSortedMeta(bool descend);
};
LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
const std::vector<DySeqMeta>& meta, const LoD& lod,
size_t level);
std::vector<size_t> GenDyBatchIndice(const DySeqMetaBatch& meta, int batch_id) {
  // collect the indices that need to be copied into the batch
std::vector<size_t> indice;
for (const auto& seq : meta) {
size_t id = seq.begin + batch_id;
if (id >= seq.end) break;
indice.push_back(id);
}
return indice;
}
} // namespace detail
const LoDTensor& TensorArray::Read(size_t index) const {
PADDLE_ENFORCE_LE(index, MAX_SIZE, "index[%d] too large", index);
if (index >= size()) {
values_.resize(index + 1);
}
return values_[index];
}
void TensorArray::Write(size_t index, const LoDTensor& value) {
PADDLE_ENFORCE_LE(index, MAX_SIZE, "index[%d] too large", index);
if (index >= size()) {
values_.resize(index + 1);
}
values_[index].set_lod(value.lod());
values_[index].Resize(value.dims());
values_[index].mutable_data<value_type>(value.place());
values_[index].CopyFrom(value, value.place(), platform::CPUDeviceContext());
}
void TensorArray::WriteShared(size_t index, const LoDTensor& value) {
PADDLE_ENFORCE_LE(index, MAX_SIZE, "index[%d] too large", index);
if (index >= size()) {
values_.resize(index + 1);
}
values_[index].set_lod(value.lod());
values_[index].ShareDataWith(value);
}
LoDTensor TensorArray::Pack(size_t level, const std::vector<DySeqMeta>& meta,
const LoD& lod) const {
return detail::PackDynamicBatch(values_, meta, lod, level);
}
DySeqMetaBatch TensorArray::Unpack(const LoDTensor& source, int level,
bool length_desend) {
detail::DynamicBatchUnpacker unpacker(source, level,
length_desend /*descend*/);
// find max length of all the sequences
size_t max_length = 0;
for (const auto& seq : unpacker.meta) {
max_length = std::max(max_length, seq.end - seq.begin);
}
// write batches to values
for (size_t batch_id = 0; batch_id < max_length; batch_id++) {
Write(batch_id, unpacker.GetBatch(batch_id));
}
PADDLE_ENFORCE(!unpacker.meta.empty());
return unpacker.meta;
}
LoDTensor TensorArray::LodPack(size_t level) const {
PADDLE_ENFORCE_GT(size(), 0UL, "no time step exists");
// the levels should be no less than 2
LoDTensor merged;
const LoDTensor *pre, *cur;
pre = &Read(0);
for (size_t step = 1; step < size(); step++) {
cur = &Read(step);
PADDLE_ENFORCE_GT(cur->NumLevels(), 0);
PADDLE_ENFORCE_GT(pre->NumLevels(), 0);
PADDLE_ENFORCE_EQ(pre->NumLevels(), cur->NumLevels());
PADDLE_ENFORCE_EQ(pre->NumElements(level), cur->NumElements(level));
merged = LodPackTwo(*pre, *cur, level);
pre = &merged;
}
return merged;
}
/*
* NOTE currently, only the lowest level supports packing.
* The lowest LoD will be changed, while the relative offsets in levels above
* stay unchanged.
*
* previous step : [0] [1] [3]
* current step: [0 1 2] [2 3] []
* packed to
* [0 0] [0 1] [0 2] [1 2] [1 3] [3]
*/
LoDTensor TensorArray::LodPackTwo(const LoDTensor& pre, const LoDTensor& cur,
size_t level) const {
PADDLE_ENFORCE_EQ(pre.NumLevels(), cur.NumLevels());
PADDLE_ENFORCE_EQ(pre.NumLevels(), level + 1,
"Only the lowest LoD level supports pack temporarily.");
// calculate the result tensor's shape first
size_t num_instances = 0;
for (size_t elem = 0; elem < pre.NumElements(level); elem++) {
size_t prefix_size = pre.NumElements(level, elem);
size_t num_candidates = cur.NumElements(level, elem);
if (num_candidates > 0) {
num_instances += num_candidates * (prefix_size + 1);
} else {
num_instances += prefix_size;
}
}
auto res_dims = pre.dims();
res_dims[0] = num_instances;
LoDTensor result;
result.Resize(res_dims);
result.mutable_data<value_type>(cur.place());
Vector<size_t> last_lod_level;
// copy data
size_t index = 0;
last_lod_level.push_back(index);
for (size_t elem = 0; elem < pre.NumElements(level); elem++) {
size_t prefix_size = pre.NumElements(level, elem);
size_t num_candidates = cur.NumElements(level, elem);
// slice the prefix Tensor
LoDTensor prefix = pre;
prefix.ShrinkInLevel(level, elem, elem + 1);
LoDTensor candidate = cur;
if (num_candidates > 0) {
candidate.ShrinkInLevel(level, elem, elem + 1);
} else { // just push prefix
result.Slice(index, index + prefix_size)
.CopyFrom(prefix, result.place(), platform::CPUDeviceContext());
index += prefix_size;
last_lod_level.push_back(index);
}
for (size_t candi = 0; candi < num_candidates; candi++) {
// TODO(superjom) support GPU
result.Slice(index, index + prefix_size)
.CopyFrom(prefix, result.place(), platform::CPUDeviceContext());
index += prefix_size;
// copy candidate record
result.Slice(index, index + 1)
.CopyFrom(candidate.Slice(candi, candi + 1), result.place(),
platform::CPUDeviceContext());
index++;
last_lod_level.push_back(index);
}
}
// update lod
auto lod = cur.lod();
lod.back() = last_lod_level;
result.set_lod(lod);
return result;
}
/*
 * source [0 1 2] [3 4] [5 6 7] will be transformed into a list of LoDTensors
 * such as
* [0 3 5] [1 4 6] [2 7] with 1-level LoDs:
* - [0 1 2 3]
* - [0 1 2 3]
* - [0 1 1 2], the [1,1) here means the second sequence is empty
*
 * NOTE Unpacking a LoDTensor in this way may result in a big LoD.
*/
void TensorArray::LodUnpack(const LoDTensor& source, size_t level) {
PADDLE_ENFORCE_EQ(level, source.NumLevels() - 1,
"only the lowest LoD level supports unpack.");
const size_t non_empty_instances = source.dims()[0];
size_t index = 0;
Vector<size_t> lowest_lod_level;
lowest_lod_level.push_back(index);
for (size_t step = 0; step < non_empty_instances; step++) {
size_t num_instances = 0;
for (size_t id = 0; id < source.NumElements(level); id++) {
auto instance = source;
instance.ShrinkInLevel(level, id, id + 1);
if (static_cast<size_t>(instance.dims()[0]) > step) {
num_instances++;
index++;
}
lowest_lod_level.push_back(index);
}
// create tensor for this time step
LoDTensor tensor;
auto dims = source.dims();
dims[0] = num_instances;
// set lod
auto lod = source.lod();
lod.back() = lowest_lod_level;
tensor.set_lod(lod);
index = 0;
for (size_t id = 0; id < source.NumElements(level); id++) {
auto instance = source;
instance.ShrinkInLevel(level, id, id + 1);
if (static_cast<size_t>(instance.dims()[0]) > step) {
// copy this instance
tensor.Slice(index, index + 1)
.CopyFrom(instance.Slice(step, step + 1), tensor.place(),
platform::CPUDeviceContext());
index++;
}
}
Write(step, tensor);
}
}
LoDTensor TensorArray::Stack() const {
LoDTensor result;
if (size() == 0) return result;
const auto& first_dims = values_.front().dims();
// check all the values have the same shape
// TODO(superjom) check the same dtypes
for (size_t idx = 1; idx < size(); idx++) {
const auto& value_dims = values_[idx].dims();
PADDLE_ENFORCE_EQ(first_dims, value_dims);
}
// copy
auto result_dims = vectorize(first_dims);
result_dims.insert(result_dims.begin(), size());
result.Resize(make_ddim(result_dims));
result.mutable_data<value_type>(platform::CPUPlace());
for (size_t idx = 0; idx < size(); idx++) {
result.Slice(idx, idx + 1)
.CopyFrom(Read(idx), platform::CPUPlace(),
platform::CPUDeviceContext());
}
return result;
}
void TensorArray::Unstack(const LoDTensor& source) const {
Unstack(source, false /*data_shared*/);
}
void TensorArray::UnstackShared(const LoDTensor& source) const {
Unstack(source, true /*data_shared*/);
}
void TensorArray::Unstack(const LoDTensor& source, bool data_shared) const {
size_t first_dim = source.dims()[0];
DDim value_dims = slice_ddim(source.dims(), 1, source.dims().size());
PADDLE_ENFORCE_GT(first_dim, 0,
"source should have some data to be unstacked");
values_.resize(first_dim);
for (size_t elem = 0; elem < first_dim; elem++) {
// create a new value
auto& value = values_[elem];
if (data_shared) {
// share memory
value.ShareDataWith(source.Slice(elem, elem + 1));
} else {
// copy
value.Resize(value_dims);
value.CopyFrom(source.Slice(elem, elem + 1), platform::CPUPlace(),
platform::CPUDeviceContext());
}
}
}
size_t TensorArray::size() const { return values_.size(); }
namespace detail {
void DynamicBatchUnpacker::BuildLengthSortedMeta(bool descend) {
PADDLE_ENFORCE(meta.empty(), "duplicate build meta");
// collect meta for each sequence in some level
auto lod = SliceLevels(source->lod(), level, level + 1)[0];
for (size_t seq_id = 0; seq_id < lod.size() - 1; seq_id++) {
DySeqMeta seq_meta({lod[seq_id], lod[seq_id + 1], seq_id});
meta.push_back(seq_meta);
}
PADDLE_ENFORCE_GT(meta.size(), 0, "meta is empty");
// sort by length
sort(meta.begin(), meta.end(),
[descend](const DySeqMeta& a, const DySeqMeta& b) {
bool a_ge_b = (a.end - a.begin) > (b.end - b.begin);
return descend ? a_ge_b : !a_ge_b;
});
}
LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
PADDLE_ENFORCE(!meta.empty(), "should build meta first");
LoDTensor result;
auto indice = detail::GenDyBatchIndice(meta, index);
PADDLE_ENFORCE(!indice.empty(), "invalid batch at %d", index);
  // copy the records selected by `indice` out of the LoDTensor
auto record_dims = slice_ddim(source->dims(), 1, source->dims().size());
auto record_dims_vec = vectorize(record_dims);
record_dims_vec.insert(record_dims_vec.begin(), indice.size());
result.Resize(make_ddim(record_dims_vec));
result.mutable_data<value_type>(platform::CPUPlace());
for (size_t i = 0; i < indice.size(); i++) {
auto index = indice[i];
auto target = result.Slice(i, i + 1);
auto slice = source->Slice(index, index + 1);
target.CopyFrom(slice, platform::CPUPlace(), platform::CPUDeviceContext());
}
return result;
}
// TODO(supejom) to cache lod if reasonable
LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
const std::vector<DySeqMeta>& meta, const LoD& lod,
size_t level) {
PADDLE_ENFORCE(!source.empty());
PADDLE_ENFORCE(!meta.empty());
PADDLE_ENFORCE(!lod.empty());
LoDTensor result;
// init result space
auto record_dims = slice_ddim(source[0].dims(), 1, source[0].dims().size());
auto record_dims_vec = vectorize(record_dims);
auto height = lod[level].back();
record_dims_vec.insert(record_dims_vec.begin(), height);
result.Resize(make_ddim(record_dims_vec));
result.mutable_data<float>(platform::CPUPlace());
for (size_t batch_id = 0; batch_id < source.size(); batch_id++) {
for (size_t seq_id = 0; seq_id < meta.size(); seq_id++) {
const auto& seq_meta = meta[seq_id];
// source is source[batch_id][seq_id]
// target is result[index]
auto index = seq_meta.begin + batch_id;
if (index >= seq_meta.end) break;
auto source_ = source[batch_id].Slice(seq_id, seq_id + 1);
auto target = result.Slice(index, index + 1);
target.CopyFrom(source_, platform::CPUPlace(),
platform::CPUDeviceContext());
}
}
result.set_lod(lod);
return result;
}
} // namespace detail
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/framework/lod_tensor.h"
namespace paddle {
namespace framework {
/*
 * DyBatchSeqPosition stores indices of the basic elements in a tensor. It is
 * used after a lod-tensor's re-assembling; its info can be used to recover the
 * order in the original lod-tensor.
*/
struct DySeqMeta {
DySeqMeta(size_t begin, size_t end, size_t ori_idx)
: begin(begin), end(end), ori_idx(ori_idx) {}
size_t begin;
size_t end; // not included
size_t ori_idx;
};
using DySeqMetaBatch = std::vector<DySeqMeta>;
/*
* Extract the indices of instances.
*/
std::vector<size_t> GenDyBatchIndice(const DySeqMetaBatch &metas, int batch_id);
/*
 * TensorArray is a C-array-like array of tensors; it is meant to be used with
 * dynamic iteration primitives such as while_loop. It is used to segment inputs
 * and store states in all time steps.
 *
 * By providing some methods similar to a C++ array, the definition of some
 * state-based dynamic models such as RNN could be more natural and highly
 * flexible.
*/
class TensorArray {
public:
using value_type = float;
  // The maximum number of values allowed to be stored.
const size_t MAX_SIZE{100000};
/*
* Read the value at location `index` in the `TensorArray`.
*/
const LoDTensor &Read(size_t index) const;
/*
* Write value into the index of the TensorArray.
*/
void Write(size_t index, const LoDTensor &value);
/*
* Write value into the index of the TensorArray, with memory shared.
*/
void WriteShared(size_t index, const LoDTensor &value);
/*
 * Recover the original LoD-arranged LoDTensor from `values`, using `level`,
 * `meta` and `lod`.
*/
LoDTensor Pack(size_t level, const DySeqMetaBatch &meta,
const LoD &lod) const;
/*
 * Split the LoDTensor at some `level` and write the generated batches to
 * `values`; if `length_desend` is set, sort by length in descending order,
 * otherwise in ascending order.
*/
DySeqMetaBatch Unpack(const LoDTensor &source, int level, bool length_desend);
/*
* Pack an array of LoDTensors to a LoDTensor.
*/
LoDTensor LodPack(size_t level) const;
/*
* Unpack a LoDTensor to an array of LoDTensors.
*/
void LodUnpack(const LoDTensor &source, size_t level);
/*
* Pack the values into a tensor with rank one higher than each tensor in
* values.
*/
LoDTensor Stack() const;
/*
* Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
*/
void Unstack(const LoDTensor &source) const;
/*
* Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
* with memory of tensors shared.
*/
void UnstackShared(const LoDTensor &source) const;
/*
* Return the number of values.
*/
size_t size() const;
protected:
void Unstack(const LoDTensor &source, bool data_shared) const;
LoDTensor LodPackTwo(const LoDTensor &pre, const LoDTensor &cur,
size_t level) const;
private:
mutable std::vector<LoDTensor> values_;
}; // class TensorArray
} // namespace framework
} // namespace paddle
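For reference, a minimal usage sketch of the TensorArray interface declared above; the function name, shapes, and values are illustrative only:

#include "paddle/framework/tensor_array.h"

void tensor_array_example() {
  using paddle::framework::LoDTensor;
  using paddle::framework::TensorArray;

  // A [4, 8] source tensor: 4 time steps of width 8.
  LoDTensor source;
  source.Resize(paddle::framework::make_ddim({4, 8}));
  source.mutable_data<float>(paddle::platform::CPUPlace());

  TensorArray ta;
  ta.Unstack(source);                   // ta.size() == 4, each step is [1, 8]
  const LoDTensor& step0 = ta.Read(0);  // read one time step
  ta.Write(2, step0);                   // overwrite step 2 with a copy of step 0
  LoDTensor stacked = ta.Stack();       // stack back into a [4, 8] tensor
  (void)stacked;
}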
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/tensor_array.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
class TensorArrayTester : public ::testing::Test {
protected:
void SetUp() override {
LoDTensor source;
source.Resize(make_ddim({batch_size, dim}));
int* data = source.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < 16 * 32; i++) {
data[i] = i;
}
ta.Unstack(source);
}
TensorArray ta;
const int batch_size = 16;
const int dim = 32;
};
TEST_F(TensorArrayTester, Read) {
for (int i = 0; i < batch_size; i++) {
const auto& tensor = ta.Read(i);
ASSERT_EQ(tensor.dims()[0], 1);
ASSERT_EQ(tensor.dims()[1], dim);
}
}
TEST_F(TensorArrayTester, Write) {
LoDTensor source;
source.Resize(make_ddim({1, dim}));
for (int i = 0; i < dim; i++) {
*(source.mutable_data<int>(platform::CPUPlace()) + i) = i;
}
ta.Write(2, source);
const auto& tensor = ta.Read(2);
for (int i = 0; i < dim; i++) {
EXPECT_EQ(*(tensor.data<int>() + i), *(source.data<int>() + i));
}
}
TEST_F(TensorArrayTester, WriteShared) {
LoDTensor source;
source.Resize(make_ddim({1, dim}));
for (int i = 0; i < dim; i++) {
*(source.mutable_data<int>(platform::CPUPlace()) + i) = i;
}
ta.WriteShared(2, source);
const auto& tensor = ta.Read(2);
for (int i = 0; i < dim; i++) {
EXPECT_EQ(*(tensor.data<int>() + i), *(source.data<int>() + i));
}
EXPECT_EQ(source.data<int>(), tensor.data<int>());
}
class TensorArrayPackTester : public ::testing::Test {
protected:
virtual void SetUp() override {
lod.push_back(std::vector<size_t>{0, 2, 9, 13});
source.set_lod(lod);
source.Resize(make_ddim({13, 128}));
source.mutable_data<int>(platform::CPUPlace());
    // content of each sentence: 0 1 2 3 4
const auto& level = lod.front();
for (size_t i = 0; i < level.size() - 1; i++) {
size_t begin = level[i];
size_t end = level[i + 1];
for (size_t j = begin; j < end; j++) {
auto record = source.Slice(j, j + 1);
for (int dim = 0; dim < 128; dim++) {
record.mutable_data<int>(platform::CPUPlace())[dim] = j - begin;
}
}
}
// unpack
meta = ta.Unpack(source, 0, true);
}
LoD lod;
TensorArray ta;
LoDTensor source;
std::vector<DySeqMeta> meta;
};
TEST_F(TensorArrayPackTester, Unpack) {
ASSERT_EQ(ta.size(), 7UL);
const auto& t0 = ta.Read(0);
const auto& t1 = ta.Read(1);
ASSERT_EQ(t0.data<int>()[0], int(0));
ASSERT_EQ(t1.data<int>()[0], int(1));
}
TEST_F(TensorArrayPackTester, Pack) {
LoDTensor packed = ta.Pack(0, meta, lod);
}
TEST_F(TensorArrayTester, size) {
ASSERT_EQ(ta.size(), static_cast<size_t>(batch_size));
}
TEST(TensorArray, LodPack) {
// three time steps, each step stores a LoDTensors
// - [0] [1]
// - [2 3], [4 5]
// - [6 7] [] [8], [9, 10]
// try to get a LoDTensor with content:
// - [0 2 6]
// - [0 2 7]
// - [0 3]
// - [1 4 8]
// - [1 5 9]
// - [1 5 10]
std::array<LoDTensor, 3> tensors;
tensors[0].Resize(make_ddim({2, 1}));
tensors[1].Resize(make_ddim({4, 1}));
tensors[2].Resize(make_ddim({5, 1}));
int index = 0;
for (auto& t : tensors) {
t.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < t.dims()[0]; i++) {
t.data<int>()[i] = index;
index++;
}
}
std::array<LoD, 3> lods;
std::vector<std::vector<size_t>> levels{
{0, 1, 2}, {0, 2, 4}, {0, 2, 2, 3, 5}};
for (int i = 0; i < 3; i++) {
lods[i].emplace_back(levels[i].begin(), levels[i].end());
}
TensorArray ta;
for (int i = 0; i < 3; i++) {
tensors[i].set_lod(lods[i]);
ta.Write(i, tensors[i]);
}
auto merged = ta.LodPack(0);
std::vector<int> target_tensor_data{{0, 2, 6, // 0
0, 2, 7, // 1
0, 3, // 2
1, 4, 8, // 3
1, 5, 9, // 5
1, 5, 10}};
EXPECT_EQ(merged.dims()[0], (int)target_tensor_data.size());
for (size_t i = 0; i < target_tensor_data.size(); i++) {
EXPECT_EQ(target_tensor_data[i], merged.data<int>()[i]);
}
}
} // namespace framework
} // namespace paddle
......@@ -150,84 +150,6 @@ inline Tensor& Tensor::ShareDataWith(const Tensor& src) {
return *this;
}
inline void Tensor::CopyFrom(const Tensor& src,
const platform::Place& dst_place,
const platform::DeviceContext& ctx) {
src.check_memory_size();
Resize(src.dims());
auto src_place = src.holder_->place();
auto src_ptr = src.data<void>();
auto dst_ptr = mutable_data(dst_place, src.type());
auto size = src.numel() * SizeOfType(src.type());
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) &&
platform::is_cpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::GPUPlace>(src_place);
auto dst_cpu_place = boost::get<platform::CPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_cpu_place = boost::get<platform::CPUPlace>(src_place);
auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::GPUPlace>(src_place);
auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
template <typename T>
inline void Tensor::CopyFromVector(const std::vector<T>& src,
const platform::DeviceContext& ctx) {
auto dst_place = ctx.GetPlace();
auto src_ptr = static_cast<const void*>(src.data());
platform::CPUPlace src_place;
auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
auto size = src.size() * sizeof(T);
if (platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, src_place,
src_ptr, size);
}
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(dst_place)) {
memory::Copy(
boost::get<platform::GPUPlace>(dst_place), dst_ptr, src_place, src_ptr,
size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
check_memory_size();
PADDLE_ENFORCE_GE(begin_idx, 0,
......
......@@ -188,178 +188,6 @@ TEST(Tensor, Slice) {
#endif
}
TEST(Tensor, CopyFrom) {
using namespace paddle::framework;
using namespace paddle::platform;
{
Tensor src_tensor;
Tensor dst_tensor;
CPUDeviceContext cpu_ctx((CPUPlace()));
int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
auto cpu_place = new paddle::platform::CPUPlace();
dst_tensor.CopyFrom(src_tensor, *cpu_place, cpu_ctx);
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
Tensor slice_tensor = src_tensor.Slice(1, 2);
dst_tensor.CopyFrom(slice_tensor, *cpu_place, cpu_ctx);
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
}
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
Tensor gpu_tensor;
Tensor dst_tensor;
int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
// CPU Tensor to GPU Tensor
auto gpu_place = new paddle::platform::GPUPlace(0);
CUDADeviceContext gpu_ctx(*gpu_place);
gpu_tensor.CopyFrom(src_tensor, *gpu_place, gpu_ctx);
// GPU Tensor to CPU Tensor
auto cpu_place = new paddle::platform::CPUPlace();
dst_tensor.CopyFrom(gpu_tensor, *cpu_place, gpu_ctx);
// Sync before Compare Tensors
gpu_ctx.Wait();
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
Tensor slice_tensor = src_tensor.Slice(1, 2);
// CPU Slice Tensor to GPU Tensor
gpu_tensor.CopyFrom(slice_tensor, *gpu_place, gpu_ctx);
// GPU Tensor to CPU Tensor
dst_tensor.CopyFrom(gpu_tensor, *cpu_place, gpu_ctx);
// Sync before Compare Slice Tensors
gpu_ctx.Wait();
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
}
#endif
}
TEST(Tensor, CopyFromVector) {
using namespace paddle::framework;
using namespace paddle::platform;
{
std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
Tensor cpu_tensor;
// Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace();
CPUDeviceContext cpu_ctx(*cpu_place);
cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);
// Compare Tensors
const int* cpu_ptr = cpu_tensor.data<int>();
const int* src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
}
src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2}));
cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);
cpu_ptr = cpu_tensor.data<int>();
src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr);
for (size_t i = 0; i < 5; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
}
delete cpu_place;
}
#ifdef PADDLE_WITH_CUDA
{
std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
Tensor cpu_tensor;
Tensor gpu_tensor;
Tensor dst_tensor;
// Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace();
CPUDeviceContext cpu_ctx(*cpu_place);
cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);
// Copy to GPUTensor
gpu_tensor.Resize(make_ddim({3, 3}));
auto gpu_place = new paddle::platform::GPUPlace();
CUDADeviceContext gpu_ctx(*gpu_place);
gpu_tensor.CopyFromVector<int>(src_vec, gpu_ctx);
// Copy from GPU to CPU tensor for comparison
dst_tensor.CopyFrom(gpu_tensor, *cpu_place, gpu_ctx);
// Sync before Compare Tensors
gpu_ctx.Wait();
const int* src_ptr = src_vec.data();
const int* cpu_ptr = cpu_tensor.data<int>();
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, cpu_ptr);
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2}));
cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);
gpu_tensor.Resize(make_ddim({2, 2}));
gpu_tensor.CopyFromVector<int>(src_vec, gpu_ctx);
dst_tensor.CopyFrom(gpu_tensor, *cpu_place, gpu_ctx);
// Sync before Compare Tensors
gpu_ctx.Wait();
src_ptr = src_vec.data();
cpu_ptr = cpu_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, cpu_ptr);
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 5; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
delete cpu_place;
delete gpu_place;
}
#endif
}
TEST(Tensor, ReshapeToMatrix) {
using namespace paddle::framework;
using namespace paddle::platform;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
namespace paddle {
namespace framework {
/**
 * @brief Copy the content of an external tensor to a new place.
 *
 * @param[in] src The external tensor.
 * @param[in] dst_place The destination place.
 * @param[in] ctx The device context containing device resources.
 * @param[out] dst The destination tensor.
 *
 * @note CopyFrom supports CPU <-> GPU and GPU <-> GPU copies.
*/
inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
const platform::DeviceContext& ctx, Tensor* dst) {
src.check_memory_size();
dst->Resize(src.dims());
auto src_place = src.place();
auto src_ptr = src.data<void>();
auto dst_ptr = dst->mutable_data(dst_place, src.type());
auto size = src.numel() * SizeOfType(src.type());
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::GPUPlace>(src_place);
auto dst_cpu_place = boost::get<platform::CPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_cpu_place = boost::get<platform::CPUPlace>(src_place);
auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::GPUPlace>(src_place);
auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
/**
 * @brief Copy the content of an external vector to a tensor.
 *
 * @param[in] src The external vector.
 * @param[in] ctx The device context containing device resources.
 * @param[out] dst The destination tensor.
 *
 * @note CopyFromVector resizes dst to a 1-D tensor with src.size() elements
 * before copying.
*/
template <typename T>
inline void CopyFromVector(const std::vector<T>& src,
const platform::DeviceContext& ctx, Tensor* dst) {
auto dst_place = ctx.GetPlace();
auto src_ptr = static_cast<const void*>(src.data());
platform::CPUPlace src_place;
dst->Resize({static_cast<int64_t>(src.size())});
auto dst_ptr = static_cast<void*>(dst->mutable_data<T>(dst_place));
auto size = src.size() * sizeof(T);
if (platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, src_place,
src_ptr, size);
}
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(dst_place)) { // NOLINT
memory::Copy(
boost::get<platform::GPUPlace>(dst_place), dst_ptr, src_place, src_ptr,
size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
/**
 * @brief Copy the content of a tensor to a vector.
 *
 * @param[in] src The source tensor.
 * @param[in] ctx The device context containing device resources.
 * @param[out] dst The destination vector.
 *
 * @note CopyToVector resizes dst to src.numel() elements before copying.
*/
template <typename T>
inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx,
std::vector<T>* dst) {
auto src_ptr = static_cast<const void*>(src.data<T>());
auto size = src.numel() * sizeof(T);
platform::CPUPlace dst_place;
dst->resize(src.numel());
auto dst_ptr = static_cast<void*>(dst->data());
if (platform::is_cpu_place(src.place())) {
memory::Copy(dst_place, dst_ptr, boost::get<platform::CPUPlace>(src.place()),
src_ptr, size);
}
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src.place())) { // NOLINT
memory::Copy(
dst_place, dst_ptr, boost::get<platform::GPUPlace>(src.place()), src_ptr,
size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
} // namespace framework
} // namespace paddle
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "paddle/framework/tensor_util.h"
#include <gtest/gtest.h>
#include <string>
namespace paddle {
namespace framework {
TEST(CopyFrom, Tensor) {
Tensor src_tensor;
Tensor dst_tensor;
platform::CPUDeviceContext cpu_ctx((platform::CPUPlace()));
int* src_ptr =
src_tensor.mutable_data<int>(make_ddim({3, 3}), platform::CPUPlace());
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
auto cpu_place = new platform::CPUPlace();
CopyFrom(src_tensor, *cpu_place, cpu_ctx, &dst_tensor);
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
Tensor slice_tensor = src_tensor.Slice(1, 2);
CopyFrom(slice_tensor, *cpu_place, cpu_ctx, &dst_tensor);
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
Tensor gpu_tensor;
Tensor dst_tensor;
int* src_ptr =
src_tensor.mutable_data<int>(make_ddim({3, 3}), platform::CPUPlace());
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
// CPU Tensor to GPU Tensor
auto gpu_place = new platform::GPUPlace(0);
platform::CUDADeviceContext gpu_ctx(*gpu_place);
CopyFrom(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
// GPU Tensor to CPU Tensor
auto cpu_place = new platform::CPUPlace();
CopyFrom(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor);
// Sync before Compare Tensors
gpu_ctx.Wait();
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
Tensor slice_tensor = src_tensor.Slice(1, 2);
// CPU Slice Tensor to GPU Tensor
CopyFrom(slice_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
// GPU Tensor to CPU Tensor
CopyFrom(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor);
// Sync before Compare Slice Tensors
gpu_ctx.Wait();
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
}
#endif
}
TEST(CopyFromVector, Tensor) {
using namespace paddle::framework;
using namespace paddle::platform;
{
std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
Tensor cpu_tensor;
// Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace();
CPUDeviceContext cpu_ctx(*cpu_place);
CopyFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
// Compare Tensors
const int* cpu_ptr = cpu_tensor.data<int>();
const int* src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
}
src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2}));
CopyFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
cpu_ptr = cpu_tensor.data<int>();
src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr);
for (size_t i = 0; i < 5; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
}
delete cpu_place;
}
#ifdef PADDLE_WITH_CUDA
{
std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
Tensor cpu_tensor;
Tensor gpu_tensor;
Tensor dst_tensor;
// Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace();
CPUDeviceContext cpu_ctx(*cpu_place);
CopyFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
// Copy to GPUTensor
gpu_tensor.Resize(make_ddim({3, 3}));
auto gpu_place = new paddle::platform::GPUPlace();
CUDADeviceContext gpu_ctx(*gpu_place);
CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);
// Copy from GPU to CPU tensor for comparison
CopyFrom(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor);
// Sync before Compare Tensors
gpu_ctx.Wait();
const int* src_ptr = src_vec.data();
const int* cpu_ptr = cpu_tensor.data<int>();
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, cpu_ptr);
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2}));
CopyFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
gpu_tensor.Resize(make_ddim({2, 2}));
CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);
CopyFrom(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor);
// Sync before Compare Tensors
gpu_ctx.Wait();
src_ptr = src_vec.data();
cpu_ptr = cpu_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, cpu_ptr);
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 5; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
delete cpu_place;
delete gpu_place;
}
#endif
}
TEST(CopyToVector, Tensor) {
using namespace paddle::framework;
using namespace paddle::platform;
{
Tensor src;
int* src_ptr = src.mutable_data<int>({3, 3}, CPUPlace());
for (int i = 0; i < 3 * 3; ++i) {
src_ptr[i] = i;
}
CPUPlace place;
CPUDeviceContext cpu_ctx(place);
std::vector<int> dst;
CopyToVector<int>(src, cpu_ctx, &dst);
for (int i = 0; i < 3 * 3; ++i) {
EXPECT_EQ(src_ptr[i], dst[i]);
}
}
#ifdef PADDLE_WITH_CUDA
{
std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
Tensor gpu_tensor;
GPUPlace place;
CUDADeviceContext gpu_ctx(place);
CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);
std::vector<int> dst;
CopyToVector<int>(gpu_tensor, gpu_ctx, &dst);
for (int i = 0; i < 3 * 3; ++i) {
EXPECT_EQ(src_vec[i], dst[i]);
}
}
#endif
}
} // namespace framework
} // namespace paddle
add_subdirectory(detail)
cc_library(memory SRCS memory.cc DEPS place)
cc_library(memory SRCS memory.cc DEPS place enforce)
cc_library(memcpy SRCS memcpy.cc)
cc_library(paddle_memory
......
......@@ -178,7 +178,6 @@ set(DEPS_OPS
cond_op
cross_entropy_op
recurrent_op
dynamic_recurrent_op
softmax_with_cross_entropy_op
softmax_op
sequence_softmax_op
......@@ -227,13 +226,6 @@ op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute)
if(WITH_TESTING)
op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS net_op tensor_array gtest)
else()
op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS net_op tensor_array)
endif()
op_library(recurrent_op SRCS recurrent_op.cc DEPS executor)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
......@@ -248,9 +240,6 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc
rnn/recurrent_op_utils.cc
DEPS dynamic_recurrent_op)
if(WITH_GPU)
cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif()
......
......@@ -223,6 +223,51 @@ $y = |x|$
}
};
class CeilOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CeilOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Ceil operator");
AddOutput("Y", "Output of Ceil operator");
AddComment(R"DOC(
Ceil Activation Operator.
$y = ceil(x)$
)DOC");
}
};
class FloorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FloorOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Floor operator");
AddOutput("Y", "Output of Floor operator");
AddComment(R"DOC(
Floor Activation Operator.
$y = floor(x)$
)DOC");
}
};
class RoundOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RoundOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Round operator");
AddOutput("Y", "Output of Round operator");
AddComment(R"DOC(
Round Activation Operator.
$y = [x]$
)DOC");
}
};
class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReciprocalOpMaker(framework::OpProto *proto,
......@@ -493,6 +538,15 @@ REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad,
ops::ActivationOpGrad);
REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad,
ops::ActivationOpGrad);
REGISTER_OP(floor, ops::ActivationOp, ops::FloorOpMaker, floor_grad,
ops::ActivationOpGrad);
REGISTER_OP(round, ops::ActivationOp, ops::RoundOpMaker, round_grad,
ops::ActivationOpGrad);
REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker,
reciprocal_grad, ops::ActivationOpGrad);
......
......@@ -283,6 +283,41 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
}
};
// ceil(x) = ceiling(x)
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.ceil();
}
};
template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = static_cast<T>(0) / x;
}
};
// floor(x) = flooring(x)
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
    y.device(d) = x.floor();
}
};
// round(x) = [x]
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.round();
}
};
// abs(x) = |x|
template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
......@@ -677,6 +712,9 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
__macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, FloorFunctor, ZeroGradFunctor); \
__macro(round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
......
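The three new ops above all register ZeroGradFunctor as their backward functor because ceil, floor, and round are piecewise constant: wherever they are differentiable the derivative is zero, so the backward pass only needs to produce an elementwise-zero gradient. As an editorial restatement in equation form:

$$\frac{d}{dx}\lceil x \rceil \;=\; \frac{d}{dx}\lfloor x \rfloor \;=\; \frac{d}{dx}\,\mathrm{round}(x) \;=\; 0 \qquad \text{wherever the derivative exists.}$$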
......@@ -36,7 +36,7 @@ class ArrayOp : public framework::OperatorBase {
if (platform::is_gpu_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU
framework::Tensor t;
t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
framework::CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx, &t);
dev_ctx.Wait();
offset = static_cast<size_t>(*t.data<int64_t>());
} else {
......
......@@ -102,8 +102,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
if (len == 0) {
continue;
}
out->Slice(out_offset, out_offset + len)
.CopyFrom(x[x_idx].Slice(start_offset, end_offset), place, dev_ctx);
auto slice = out->Slice(out_offset, out_offset + len);
framework::CopyFrom(x[x_idx].Slice(start_offset, end_offset), place,
dev_ctx, &slice);
out_offset += len;
}
}
......
......@@ -43,7 +43,8 @@ class AssignFunctor {
out_rows.set_rows(rows.rows());
out_rows.set_height(rows.height());
auto &t = rows.value();
out_rows.mutable_value()->CopyFrom(t, t.place(), dev_ctx_);
auto *m = out_rows.mutable_value();
framework::CopyFrom(t, t.place(), dev_ctx_, m);
}
template <typename T>
......@@ -55,7 +56,7 @@ class AssignFunctor {
void copy_tensor(const framework::LoDTensor &lod_tensor,
framework::LoDTensor *out) const {
auto &out_tensor = *out;
out_tensor.CopyFrom(lod_tensor, lod_tensor.place(), dev_ctx_);
CopyFrom(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor);
out_tensor.set_lod(lod_tensor.lod());
}
......
......@@ -17,6 +17,36 @@ limitations under the License. */
namespace paddle {
namespace operators {
struct BeamSearchDecodeFunctor {
BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores,
LoDTensor* id_tensor, LoDTensor* score_tensor)
: step_ids_(step_ids),
step_scores_(step_scores),
id_tensor_(id_tensor),
score_tensor_(score_tensor) {}
template <typename T>
void operator()() const;
const LoDTensorArray& step_ids_;
const LoDTensorArray& step_scores_;
LoDTensor* id_tensor_;
LoDTensor* score_tensor_;
};
template <typename T>
void BeamSearchDecodeFunctor::operator()() const {
BeamSearchDecoder<T> beam_search_decoder;
beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_,
score_tensor_);
}
template <>
void BeamSearchDecodeFunctor::operator()<bool>() const {
PADDLE_THROW("beam search decode op does not support bool!");
}
class BeamSearchDecodeOp : public framework::OperatorBase {
public:
BeamSearchDecodeOp(const std::string& type,
......@@ -45,9 +75,9 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
BeamSearchDecoder<float> beam_search_decoder;
beam_search_decoder.PackAllSteps(*ids, *scores, sentenceIds,
sentenceScores);
framework::VisitDataType(
framework::ToDataType(scores->at(0).type()),
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores));
}
};
......
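The operator above now dispatches on the runtime score dtype through framework::VisitDataType together with framework::ToDataType, which invokes the functor's templated operator() with the matching C++ type instead of hard-coding BeamSearchDecoder<float>. A minimal sketch of the same dispatch shape; PrintTypeFunctor is illustrative and not part of this diff:

#include <iostream>
#include <typeinfo>

// A functor with the same shape as BeamSearchDecodeFunctor: VisitDataType
// calls functor.operator()<T>() for the C++ type T matching the runtime dtype.
struct PrintTypeFunctor {
  template <typename T>
  void operator()() const {
    std::cout << "dispatched to " << typeid(T).name() << std::endl;
  }
};

// Usage, assuming a LoDTensor `scores` as in the operator above:
//   framework::VisitDataType(framework::ToDataType(scores.type()),
//                            PrintTypeFunctor());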
......@@ -232,12 +232,12 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
id_tensor->set_lod(lod);
id_tensor->Resize({static_cast<int64_t>(id_data.size())});
id_tensor->mutable_data<int64_t>(paddle::platform::CPUPlace());
id_tensor->CopyFromVector<int64_t>(id_data, cpu_ctx);
framework::CopyFromVector<int64_t>(id_data, cpu_ctx, id_tensor);
score_tensor->set_lod(lod);
score_tensor->Resize({static_cast<int64_t>(score_data.size())});
score_tensor->mutable_data<T>(paddle::platform::CPUPlace());
score_tensor->CopyFromVector<T>(score_data, cpu_ctx);
framework::CopyFromVector<T>(score_data, cpu_ctx, score_tensor);
}
template <typename T>
......
......@@ -77,11 +77,19 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output of bilinear_tensor_product operator.");
AddComment(R"DOC(
Bilinear Tensor Product operator.
Given input X and Y, a 3D tensor weight, and bias. Each column of the
output is computed by one slice i = 1, . . . , k of the tensor:
M = (X W_i) \cdot Y
Out_i = \sum_i {M_i} + Bias_i
Given inputs X and Y, a 3D tensor Weight, and a Bias, each column of the
Output is computed by one slice $i = 1, . . . , k$ of the tensor:
$$
M = (X W_i) * Y \\
Out_i = \sum_j {M_j} + Bias_i
$$
Where $W_i$ is the $i$-th slice of Input(Weight);
$M_j$ is the $j$-th column of $M$;
$Out_i$ is the $i$-th column of Output(Out);
$Bias_i$ is a column vector, each element of it is equal to
the $i$-th element of $Bias$;
)DOC");
}
......
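Equivalently, for a single example with feature vectors $x$ and $y$ (one row of X and one row of Y), the formula above reduces to a standard bilinear form per output column; this restatement is an editorial gloss rather than part of the operator's DOC string:

$$Out_i = x^\top W_i \, y + Bias_i, \qquad i = 1, \dots, k$$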
......@@ -25,8 +25,8 @@ class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of cast op");
AddOutput("Out", "The output tensor of cast op");
AddAttr<int>("out_data_type", "output data type");
AddAttr<int>("in_data_type", "input data type");
AddAttr<int>("out_dtype", "output data type");
AddAttr<int>("in_dtype", "input data type");
AddComment(R"DOC(
Cast Operator.
......@@ -58,8 +58,8 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
grad->SetType("cast");
grad->SetInput("X", OutputGrad("Out"));
grad->SetOutput("Out", InputGrad("X"));
grad->SetAttr("out_data_type", GetAttr("in_data_type"));
grad->SetAttr("in_data_type", GetAttr("out_data_type"));
grad->SetAttr("out_dtype", GetAttr("in_dtype"));
grad->SetAttr("in_dtype", GetAttr("out_dtype"));
return std::unique_ptr<framework::OpDescBind>(grad);
}
};
......
......@@ -55,7 +55,7 @@ class CastOpKernel : public framework::OpKernel<InT> {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType(
static_cast<framework::DataType>(context.Attr<int>("out_data_type")),
static_cast<framework::DataType>(context.Attr<int>("out_dtype")),
CastOpFunctor<Place, InT>(in, out, context.device_context()));
}
};
......
......@@ -38,7 +38,7 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
std::vector<int>& dilations) {
bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
for (size_t j = 0; j < strides.size(); ++j) {
filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
strides_1 = strides_1 && (strides[j] == 1);
padding_0 = padding_0 && (paddings[j] == 0);
dilation_1 = dilation_1 && (dilations[j] == 1);
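The fix above reads the spatial filter sizes from filter_dim[j + 2] instead of filter_dim[j], since the first two filter dimensions are the output and input channel counts. A standalone sketch mirroring the check (hypothetical helper, names assumed):

#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true when im2col/vol2col is actually needed; a 1x1 filter with unit
// stride, zero padding and unit dilation can reuse the input data directly.
bool NeedsIm2Col(const std::vector<int64_t>& filter_dim,   // {k_o, k_i, k_h, k_w}
                 const std::vector<int>& strides, const std::vector<int>& paddings,
                 const std::vector<int>& dilations) {
  bool trivial = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    trivial = trivial && filter_dim[j + 2] == 1 && strides[j] == 1 &&
              paddings[j] == 0 && dilations[j] == 1;
  }
  return !trivial;
}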
......@@ -91,32 +91,28 @@ class GemmConvKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
// filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
// filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
filter_shape_vec.erase(filter_shape_vec.begin(),
filter_shape_vec.begin() + 2);
// output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
// output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
output_shape_vec.erase(output_shape_vec.begin(),
output_shape_vec.begin() + 2);
// use col_shape in the im2col calculation
// col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
// o_h, o_w}
std::vector<int64_t> col_shape_vec;
col_shape_vec.push_back(input->dims()[1] / groups);
col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
filter_shape_vec.end());
col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
output_shape_vec.end());
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
// use col_matrix_shape in the gemm calculation
// size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
// o_h * o_w)
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
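The refactor above derives data_dim (2 for conv2d, 3 for conv3d) from the filter rank and fills col_shape_vec in place, instead of erasing the leading dimensions and concatenating vectors. A standalone sketch of the same shape computation with a concrete 2-D example (hypothetical helper):

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int64_t> MakeColShape(const std::vector<int64_t>& filter_dims,  // {k_o, k_i, k_h, k_w}
                                  const std::vector<int64_t>& output_dims,  // {o_n, o_c, o_h, o_w}
                                  int64_t input_channels, int64_t groups) {
  const size_t data_dim = filter_dims.size() - 2;
  std::vector<int64_t> col_shape(1 + 2 * data_dim);
  col_shape[0] = input_channels / groups;               // i_c / g
  for (size_t j = 0; j < data_dim; ++j) {
    col_shape[j + 1] = filter_dims[j + 2];               // k_h, k_w (or k_d, k_h, k_w)
    col_shape[j + 1 + data_dim] = output_dims[j + 2];    // o_h, o_w (or o_d, o_h, o_w)
  }
  // e.g. filter {64, 3, 3, 3}, output {8, 64, 32, 32}, i_c = 3, groups = 1
  //      -> col_shape {3, 3, 3, 32, 32}
  return col_shape;
}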
......@@ -159,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (filter_shape_vec.size() == 2) {
} else if (data_dim == 2U) {
// im2col
im2col(context.device_context(), in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (filter_shape_vec.size() == 3) {
} else if (data_dim == 3U) {
// vol2col
vol2col(context.device_context(), in_slice, dilations, strides,
paddings, &col);
......@@ -206,26 +202,22 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
// filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
// filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
filter_shape_vec.erase(filter_shape_vec.begin(),
filter_shape_vec.begin() + 2);
// output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
// output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
std::vector<int64_t> output_shape_vec(
framework::vectorize(output_grad->dims()));
output_shape_vec.erase(output_shape_vec.begin(),
output_shape_vec.begin() + 2);
// use col_shape in the im2col calculation
// col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
// o_h, o_w}
std::vector<int64_t> col_shape_vec;
col_shape_vec.push_back(input->dims()[1] / groups);
col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
filter_shape_vec.end());
col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
output_shape_vec.end());
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
// use col_matrix_shape in the gemm calculation
......@@ -233,7 +225,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
// or
// (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
framework::flatten_to_2d(col_shape, data_dim + 1);
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
......@@ -294,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
if (is_expand && filter_shape_vec.size() == 2) {
if (is_expand && data_dim == 2U) {
col2im(context.device_context(), col, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&in_grad_slice);
} else if (is_expand && filter_shape_vec.size() == 3) {
} else if (is_expand && data_dim == 3U) {
col2vol(context.device_context(), col, dilations, strides, paddings,
&in_grad_slice);
}
......@@ -328,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (filter_shape_vec.size() == 2) {
} else if (data_dim == 2U) {
im2col(context.device_context(), in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (filter_shape_vec.size() == 3) {
} else if (data_dim == 3U) {
vol2col(context.device_context(), in_slice, dilations, strides,
paddings, &col);
}
......
......@@ -68,30 +68,26 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {h, w} or {d, h, w}
// input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
// filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
// filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
filter_shape_vec.erase(filter_shape_vec.begin(),
filter_shape_vec.begin() + 2);
// use col_shape in the im2col and col2im (or vol2col and col2vol)
// calculation
// col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
std::vector<int64_t> col_shape_vec;
col_shape_vec.push_back(output->dims()[1]);
col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
filter_shape_vec.end());
col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
input_shape_vec.end());
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = output->dims()[1];
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
}
DDim col_shape(framework::make_ddim(col_shape_vec));
// use col_matrix_shape in the gemm calculation
// size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
......@@ -136,7 +132,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
input_batch, false, static_cast<T>(1.0),
&col_matrix, static_cast<T>(0.0));
if (filter_shape_vec.size() == 2) {
if (data_dim == 2U) {
// col2im: col_matrix -> dy
// from (c * k_h * k_w, h * w) to (c, o_h, o_w)
col2im(context.device_context(), col,
......@@ -144,7 +140,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&output_batch);
} else if (filter_shape_vec.size() == 3) {
} else if (data_dim == 3U) {
// col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
col2vol(context.device_context(), col, dilations, strides, paddings,
......@@ -176,30 +172,26 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {h, w} or {d, h, w}
// input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
// filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
// filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
filter_shape_vec.erase(filter_shape_vec.begin(),
filter_shape_vec.begin() + 2);
// use col_shape in the im2col and col2im (or vol2col and col2vol)
// calculation
// col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
std::vector<int64_t> col_shape_vec;
col_shape_vec.push_back(output_grad->dims()[1]);
col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
filter_shape_vec.end());
col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
input_shape_vec.end());
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = output_grad->dims()[1];
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
}
DDim col_shape(framework::make_ddim(col_shape_vec));
// use col_matrix_shape in the gemm calculation
// size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
// output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
......@@ -248,7 +240,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
Tensor output_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_shape);
if (filter_shape_vec.size() == 2) {
if (data_dim == 2U) {
// im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
im2col(context.device_context(), output_grad_batch,
......@@ -256,7 +248,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (filter_shape_vec.size() == 3) {
} else if (data_dim == 3U) {
// vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
vol2col(context.device_context(), output_grad_batch, dilations,
......
......@@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
if (ctx->Attrs().Get<bool>("is_training") == true) {
if (ctx->Attrs().Get<bool>("is_test") == false) {
ctx->SetOutputDim("Mask", x_dims);
}
ctx->ShareLoD("X", /*->*/ "Out");
......@@ -49,7 +49,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f);
AddAttr<bool>("is_training", "True if in training phase.").SetDefault(true);
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddComment(R"DOC(
......@@ -71,8 +71,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), true,
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
"GradOp is only callable when is_test is false");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null.");
......
......@@ -59,7 +59,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto place = context.GetEigenDevice<Place>();
if (context.Attr<bool>("is_training")) {
if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int size = framework::product(mask->dims());
......
......@@ -35,7 +35,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto* y_data = y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob");
if (context.Attr<bool>("is_training")) {
if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int seed = context.Attr<int>("seed");
......@@ -65,8 +65,8 @@ template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(context.Attr<bool>("is_training"),
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE(!context.Attr<bool>("is_test"),
"GradOp is only callable when is_test is false");
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
......
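With the attribute renamed from is_training to is_test, the kernels branch on !is_test for mask generation. A minimal CPU sketch of the forward pass, assuming the same (non-inverted) scaling convention as the kernels above (a standalone hypothetical helper on std::vector rather than framework::Tensor):

#include <cstddef>
#include <random>
#include <vector>

void DropoutForward(const std::vector<float>& x, float dropout_prob, bool is_test,
                    int seed, std::vector<float>* y, std::vector<float>* mask) {
  y->resize(x.size());
  if (!is_test) {                          // training: sample a 0/1 mask
    mask->resize(x.size());
    std::minstd_rand engine(seed);
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    for (size_t i = 0; i < x.size(); ++i) {
      (*mask)[i] = dist(engine) < dropout_prob ? 0.0f : 1.0f;
      (*y)[i] = x[i] * (*mask)[i];
    }
  } else {                                 // inference: scale by the keep probability
    for (size_t i = 0; i < x.size(); ++i) (*y)[i] = x[i] * (1.0f - dropout_prob);
  }
}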
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Scope;
using framework::TensorArray;
using framework::LoDTensor;
using framework::Variable;
using framework::OperatorBase;
using framework::DySeqMetaBatch;
namespace detail {
inline void CreateVariables(Scope& scope,
const std::vector<std::string>& var_names) {
for (const auto& name : var_names) {
scope.Var(name);
}
}
/*
* The sequence inputs are reordered when they are split, so the
* boot_states should be reordered in the same order.
*
* NOTE This may require that the `pre_state` of the first time step copies the
* `boot_state` rather than referencing it, because its content has to be
* reordered, while the RNN op must not change the content of `boot_state`,
* which is an input variable.
*/
inline void ReorderInitialState(const DySeqMetaBatch& metas,
const LoDTensor& boot_state, LoDTensor* tensor,
const platform::Place& dst_place) {
for (size_t seq_id = 0; seq_id < metas.size(); seq_id++) {
auto slice = tensor->Slice(seq_id, seq_id + 1);
auto boot_slice =
boot_state.Slice(metas[seq_id].ori_idx, metas[seq_id].ori_idx + 1);
// TODO(superjom) pass in device context as an argument
slice.CopyFrom(boot_slice, dst_place, platform::CPUDeviceContext());
}
}
inline void RestoreInitialState(const DySeqMetaBatch& metas,
const LoDTensor& tensor, LoDTensor* boot_state,
const platform::Place& dst_place) {
for (size_t seq_id = 0; seq_id < metas.size(); seq_id++) {
auto slice = tensor.Slice(seq_id, seq_id + 1);
auto boot_slice =
boot_state->Slice(metas[seq_id].ori_idx, metas[seq_id].ori_idx + 1);
boot_slice.CopyFrom(slice, dst_place, platform::CPUDeviceContext());
}
}
} // namespace detail
// Implementation for forward propagation.
template <>
void RNNAlgorithm::Run<RNNAlgorithm::ComputeMode::kForward>(
const framework::Scope& scope, const framework::OperatorBase& op,
const platform::DeviceContext& dev_ctx) {
SetComputeMode(ComputeMode::kForward);
cache_.Init(kArgNames[mode_], op, scope, &dev_ctx, &arg_);
SplitInputs();
CreateScopes();
WriteStepInputs();
InitStates();
WriteStepOutputs();
RunSteps();
ConcatOutputs();
}
// Implementation for backward propagation.
template <>
void RNNAlgorithm::Run<RNNAlgorithm::ComputeMode::kBackward>(
const framework::Scope& scope, const framework::OperatorBase& op,
const platform::DeviceContext& dev_ctx) {
SetComputeMode(ComputeMode::kBackward);
cache_.Init(kArgNames[mode_], op, scope, &dev_ctx, &arg_);
SplitInputs();
WriteStepInputs();
InitStates();
WriteStepOutputs();
RunSteps();
// copy boot-states' gradients back.
for (const auto& state : arg_.states) {
ExportInitialStateGradient(state);
}
ConcatOutputs();
}
void RNNAlgorithm::SplitInputs() {
// TODO(superjom) make level a config
// TODO(superjom) check all the inputs has the same LoD
int level = 0;
for (const auto& item : cache_.inputs) {
const auto& var = item.second;
const auto& tensor = var->Get<LoDTensor>();
TensorArray& ta = step_inputs_[item.first];
dy_seq_metas_[item.first] =
ta.Unpack(tensor, level, true /*length_descend*/);
if (cache_.num_steps) {
PADDLE_ENFORCE_EQ(ta.size(), cache_.num_steps,
"inputs should have the same steps");
} else {
cache_.num_steps = ta.size();
}
}
}
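SplitInputs unpacks each LoD input into a TensorArray with length_descend set, so the sequences are sorted by decreasing length and the batch of time step t contains exactly the sequences longer than t. A hypothetical standalone sketch of that relation:

#include <cstddef>
#include <vector>

// seq_lens_desc: sequence lengths sorted in descending order, e.g. {4, 3, 2, 1}.
// Returns the batch size of each time step: {4, 3, 2, 1} for the example above.
std::vector<size_t> StepBatchSizes(const std::vector<size_t>& seq_lens_desc) {
  const size_t num_steps = seq_lens_desc.empty() ? 0 : seq_lens_desc.front();
  std::vector<size_t> batch_sizes(num_steps, 0);
  for (size_t t = 0; t < num_steps; ++t) {
    for (size_t len : seq_lens_desc) {
      if (len > t) ++batch_sizes[t];
    }
  }
  return batch_sizes;
}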
void RNNAlgorithm::WriteStepInputs() {
for (const auto& item : cache_.inputs) {
auto ta_it = step_inputs_.find(item.first);
PADDLE_ENFORCE(ta_it != step_inputs_.end(),
"step_inputs_ not compatible with memory set");
TensorArray& ta = ta_it->second;
for (size_t step = 0; step < ta.size(); step++) {
auto tensor = ta.Read(step);
auto& step_scope = cache_.GetScope(step);
Variable* var = step_scope.FindVar(item.first);
if (var == nullptr) {
var = step_scope.Var(item.first);
}
var->GetMutable<LoDTensor>()->ShareDataWith(tensor);
}
}
}
void RNNAlgorithm::WriteStepOutputs() {
// initialize step outputs
for (const auto& item : cache_.outputs) {
step_outputs_.emplace(item.first, TensorArray());
}
PADDLE_ENFORCE_GT(step_outputs_.size(), 0UL);
}
void RNNAlgorithm::CreateScopes() {
PADDLE_ENFORCE_GT(cache_.num_steps, 0);
// resize scopes
size_t num_scopes_need_create = cache_.num_steps - cache_.scopes->size();
for (size_t i = 0; i < num_scopes_need_create; i++) {
cache_.scopes->emplace_back(&cache_.scope->NewScope());
}
// init temporary inputs
PADDLE_ENFORCE_NOT_NULL(step_unit_, "stepnet should be set first");
std::vector<std::string> states;
std::vector<std::string> ex_states;
std::vector<std::string> step_unit_outputs;
std::transform(arg_.states.begin(), arg_.states.end(),
std::back_inserter(states),
[](const rnn::StateAttr& m) { return m.var; });
std::transform(arg_.states.begin(), arg_.states.end(),
std::back_inserter(ex_states),
[](const rnn::StateAttr& m) { return m.pre_var; });
for (const auto& item : step_unit_->Outputs()) {
for (const auto& var : item.second) {
step_unit_outputs.push_back(var);
}
}
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& scope = cache_.GetScope(step);
detail::CreateVariables(scope, arg_.inlinks);
detail::CreateVariables(scope, arg_.outlinks);
detail::CreateVariables(scope, states);
detail::CreateVariables(scope, ex_states);
detail::CreateVariables(scope, step_unit_outputs);
}
}
void RNNAlgorithm::ConcatOutputs() {
// TODO(superjom) transform this to a config
int level = 0;
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& scope = cache_.GetScope(step);
for (auto& item : step_outputs_) {
auto* var = scope.FindVar(item.first);
PADDLE_ENFORCE_NOT_NULL(var);
auto* tensor = var->GetMutable<LoDTensor>();
tensor->mutable_data<value_type>(platform::CPUPlace());
item.second.WriteShared(step, *tensor);
}
}
// the inputs' LoDs should be the same, so just pick one of them.
const auto& some_lod =
cache_.scope->FindVar(arg_.inlinks.front())->Get<LoDTensor>().lod();
const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
for (auto& item : step_outputs_) {
auto tensor = item.second.Pack(level, some_meta, some_lod);
auto* output = cache_.outputs[item.first]->GetMutable<LoDTensor>();
const_cast<LoDTensor*>(output)->ShareDataWith(tensor);
}
}
void RNNAlgorithm::RunSteps() {
if (IsBackward()) {
// call stepnet in all the time steps reversely
for (int step = cache_.num_steps - 1; step >= 0; step--) {
auto& step_scope = cache_.GetScope(step);
step_unit_->Run(step_scope, *cache_.dev_ctx);
}
} else {
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& step_scope = cache_.GetScope(step);
step_unit_->Run(step_scope, *cache_.dev_ctx);
}
}
}
void RNNAlgorithm::InitStates() {
for (size_t step = 0; step < cache_.num_steps; step++) {
for (const auto& state : arg_.states) {
CreateState(state, step);
LinkState(state, step);
}
}
}
void RNNAlgorithm::CreateState(const rnn::StateAttr& state_attr, size_t step) {
auto& scope = cache_.GetScope(step);
auto& state = *cache_.GetTensor(scope, state_attr.var);
auto& boot_state = *cache_.GetTensor(*cache_.scope, state_attr.boot_var);
size_t num_instances =
step_inputs_[arg_.inlinks.front()].Read(step).dims()[0];
auto dims = boot_state.dims();
dims[0] = num_instances;
state.Resize(dims);
state.mutable_data<value_type>(platform::CPUPlace());
states_[state_attr.var].WriteShared(step, state);
}
void RNNAlgorithm::LinkState(const rnn::StateAttr& state, size_t step) {
auto& scope = cache_.GetScope(step);
auto& state_pre = *cache_.GetTensor(scope, state.pre_var);
// Process the first state's boot-state (the 0-th step in forward mode or the
// last step in backward mode).
// Only forward mode needs to link the boot-state to the `pre-state` in the
// first time step; backward mode needs to copy the gradient of the `pre-state`
// in the first time step to the gradient of the `boot-state`.
if (step == 0 && IsForward()) {
LinkInitialState(state);
} else {
size_t num_instances =
step_inputs_[arg_.inlinks.front()].Read(step).dims()[0];
auto* pre_state = cache_.GetTensor(cache_.GetScope(step - 1), state.var);
// shrink and share from the previous state
auto shrinked_pre_state = pre_state->Slice(0, num_instances);
state_pre.ShareDataWith(shrinked_pre_state);
}
}
void RNNAlgorithm::LinkInitialState(const rnn::StateAttr& state) {
// all the step_inputs' metas should be the same, so select any one of them
// and get the dyseq meta.
const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
auto& scope = cache_.GetScope(0);
auto& state_pre = *cache_.GetTensor(scope, state.pre_var);
auto* pre_state = cache_.GetTensor(*cache_.scope, state.boot_var);
pre_state->mutable_data<float>(platform::CPUPlace());
// allocate state
state_pre.Resize(pre_state->dims());
state_pre.mutable_data<value_type>(platform::CPUPlace());
detail::ReorderInitialState(some_meta, *pre_state, &state_pre,
pre_state->place());
}
void RNNAlgorithm::ExportInitialStateGradient(const rnn::StateAttr& state) {
// all the step_inputs' metas should be the same, so select any one of them
// and get the dyseq meta.
const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
auto& scope = cache_.GetScope(0);
auto& state_pre = *cache_.GetTensor(scope, state.pre_var);
auto& pre_state = *cache_.GetTensor(*cache_.scope, state.boot_var);
pre_state.Resize(state_pre.dims());
detail::RestoreInitialState(some_meta, state_pre, &pre_state,
pre_state.place());
}
void RNNAlgorithm::ArgCache::Init(const rnn::ArgumentName& name,
const paddle::framework::OperatorBase& op,
const paddle::framework::Scope& scope,
platform::DeviceContext const* dev_ctx,
rnn::Argument* arg) {
this->scope = &scope;
InitArgument(name, op, arg);
CacheScopes(scope, *arg);
CacheInlinks(scope, arg->inlinks);
CacheOutlinks(scope, arg->outlinks);
this->dev_ctx = dev_ctx;
}
void RNNAlgorithm::ArgCache::InitArgument(const rnn::ArgumentName& name,
const OperatorBase& op,
rnn::Argument* arg) {
rnn::InitArgument(name, arg, op, false /*is_grad*/);
}
void RNNAlgorithm::ArgCache::CacheScopes(const Scope& scope,
const rnn::Argument& arg) {
auto scopes_var = scope.FindVar(arg.step_scopes);
PADDLE_ENFORCE(scopes_var != nullptr,
"the step_scopes output argument [%s] should be created first "
"by framework.",
arg.step_scopes);
this->scopes = scopes_var->GetMutable<std::vector<Scope*>>();
}
void RNNAlgorithm::ArgCache::CacheInlinks(
const Scope& scope, const std::vector<std::string>& names) {
for (auto name : names) {
auto* var = GetVariable(scope, name);
inputs[name] = var;
}
}
void RNNAlgorithm::ArgCache::CacheOutlinks(
const Scope& scope, const std::vector<std::string>& names) {
for (auto name : names) {
auto* var = GetVariable(scope, name);
outputs[name] = var;
}
}
Variable* RNNAlgorithm::ArgCache::GetVariable(const Scope& scope,
const std::string& name) {
auto* var = scope.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var, "variable [%s] does not exist in scope", name);
return var;
}
LoDTensor* RNNAlgorithm::ArgCache::GetTensor(const framework::Scope& scope,
const std::string& name) {
auto* var = GetVariable(scope, name);
return var->GetMutable<LoDTensor>();
}
const std::array<rnn::ArgumentName, 2> RNNAlgorithm::kArgNames{
{rnn::ArgumentName{"step_unit", "step_scopes", "inputs", "outputs",
"states", "ex_states", "initial_states"},
rnn::ArgumentName{"step_unit", "step_scopes@GRAD", "outputs@GRAD",
"inputs@GRAD", "states", "ex_states",
"initial_states@GRAD"}}};
void DynamicRecurrentOp::Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const {
rnn.Run<RNNAlgorithm::ComputeMode::kForward>(
scope, *dynamic_cast<const OperatorBase*>(this), dev_ctx);
}
void DynamicRecurrentGradientOp::Run(
const Scope& scope, const platform::DeviceContext& dev_ctx) const {
rnn.Run<RNNAlgorithm::ComputeMode::kBackward>(
scope, *dynamic_cast<const OperatorBase*>(this), dev_ctx);
}
class DynamicRecurrentOpProtoAndCheckerMaker
: public framework::OpProtoAndCheckerMaker {
public:
DynamicRecurrentOpProtoAndCheckerMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name =
RNNAlgorithm::kArgNames[RNNAlgorithm::ComputeMode::kForward];
// inputs and outputs stored in proto
AddInput(name.inlinks,
"The inputs that need to be segmented for each step.")
.AsDuplicable();
AddInput(name.initial_states, "Variables to initialize the states.")
.AsDuplicable();
AddOutput(name.outlinks,
"The outputs that need to be concatenated for all steps.")
.AsDuplicable();
AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap
AddAttr<std::vector<std::string>>(name.ex_states, "names of ex_states");
AddAttr<std::vector<std::string>>(name.states, "names of states");
AddComment(R"DOC(
Dynamic Recurrent Operator.
This is an RNN operator for variable-length sequences.
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(dynamic_recurrent, paddle::operators::DynamicRecurrentOp,
paddle::operators::DynamicRecurrentOpProtoAndCheckerMaker,
dynamic_recurrent_grad,
paddle::operators::DynamicRecurrentGradientOp);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_TESTING
#include "gtest/gtest.h"
#endif
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor_array.h"
#include "paddle/framework/variable.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"
namespace paddle {
namespace operators {
class RNNAlgorithm {
public:
enum ComputeMode { kForward = 0, kBackward = 1 };
static const std::array<rnn::ArgumentName, 2> kArgNames;
using value_type = float;
/*
* Different `Run` methods for forward and backward; `_` is just for template
* specialization.
*/
template <ComputeMode _>
void Run(const framework::Scope& scope, const framework::OperatorBase& op,
const platform::DeviceContext& dev_ctx);
/*
* Split the inputs (LoDTensors) into segments, one for each time step.
*/
void SplitInputs();
/*
* Create step-scopes to store temporary outputs in each time step.
*/
void CreateScopes();
/*
* Link TensorArray steps to the corresponding variables located in
* step-scopes.
*/
void WriteStepInputs();
/*
* Write output of each step to the corresponding TensorArray.
*/
void WriteStepOutputs();
/*
* Initialize the states; each state will have a corresponding pre-state,
* which shares memory with the state in the previous time step. The
* pre-state in the first time step will be initialized with a zero tensor, or
* with a tensor from the parent scope if one is provided.
*/
void InitStates();
/*
* Create state variables for each time step.
*/
void CreateState(const rnn::StateAttr& state, size_t step);
/*
* Link pre-state variable in current scope to the state variable in the
* previous time step (scope) by reference.
*/
void LinkState(const rnn::StateAttr& state, size_t step);
/*
* Link the pre-state of the first time step to the `boot-state` in parent's
* scope.
*/
void LinkInitialState(const rnn::StateAttr& state);
/*
* Copy the gradient from `pre-state` in the first step-scope to the
* `boot-state` in parent's scope.
*/
void ExportInitialStateGradient(const rnn::StateAttr& state);
/*
* Run the step unit over all the time steps.
*/
void RunSteps();
/*
* Concatenate outputs in each time step and generate a LoDTensor.
*/
void ConcatOutputs();
void SetComputeMode(ComputeMode mode) { mode_ = mode; }
bool IsForward() const { return mode_ == ComputeMode::kForward; }
bool IsBackward() const { return mode_ == ComputeMode::kBackward; }
/*
* Set a step unit that is created from a RecurrentOp's step unit.
*/
void SetStepUnit(std::unique_ptr<framework::OperatorBase> step_unit) {
PADDLE_ENFORCE_NOT_NULL(step_unit);
step_unit_ = std::move(step_unit);
}
const framework::OperatorBase& GetStepUnit() const { return *step_unit_; }
const framework::TensorArray& state(const std::string& name) const {
auto it = states_.find(name);
PADDLE_ENFORCE(it != states_.end());
return it->second;
}
const framework::TensorArray& step_input(const std::string& name) const {
auto it = step_inputs_.find(name);
PADDLE_ENFORCE(it != step_inputs_.end());
return it->second;
}
const framework::TensorArray& step_output(const std::string& name) const {
auto it = step_outputs_.find(name);
PADDLE_ENFORCE(it != step_outputs_.end());
return it->second;
}
protected:
struct ArgCache {
framework::Scope const* scope;
std::vector<framework::Scope*>* scopes;
std::map<std::string, framework::Variable*> inputs;
std::map<std::string, framework::Variable*> outputs;
platform::DeviceContext const* dev_ctx;
size_t num_steps{0};
void Init(const rnn::ArgumentName& name, const framework::OperatorBase& op,
const framework::Scope& scope,
platform::DeviceContext const* dev_ctx, rnn::Argument* arg);
framework::Scope& GetScope(size_t index) {
PADDLE_ENFORCE_LT(index, num_steps);
return *scopes->at(index);
}
framework::LoDTensor* GetTensor(const framework::Scope& scope,
const std::string& name);
private:
void InitArgument(const rnn::ArgumentName& name,
const framework::OperatorBase& op, rnn::Argument* arg);
void CacheScopes(const framework::Scope& scope, const rnn::Argument& arg);
void CacheInlinks(const framework::Scope& scope,
const std::vector<std::string>& names);
void CacheOutlinks(const framework::Scope& scope,
const std::vector<std::string>& names);
framework::Variable* GetVariable(const framework::Scope& scope,
const std::string& name);
};
private:
std::unique_ptr<framework::OperatorBase> step_unit_;
std::map<std::string, framework::TensorArray> states_;
std::map<std::string, framework::TensorArray> step_inputs_;
std::map<std::string, framework::TensorArray> step_outputs_;
std::map<std::string, std::vector<framework::DySeqMeta>> dy_seq_metas_;
rnn::Argument arg_;
ArgCache cache_;
ComputeMode mode_{ComputeMode::kForward};
#ifdef PADDLE_WITH_TESTING
// test forward
friend class RNNAlgorithmTestHelper;
FRIEND_TEST(RNNAlgorithmTestHelper, SplitInputs);
FRIEND_TEST(RNNAlgorithmTestHelper, CreateCache);
FRIEND_TEST(RNNAlgorithmTestHelper, CreateScopes);
FRIEND_TEST(RNNAlgorithmTestHelper, WriteStepInputs);
FRIEND_TEST(RNNAlgorithmTestHelper, WriteStepOutputs);
FRIEND_TEST(RNNAlgorithmTestHelper, InitStates);
FRIEND_TEST(RNNAlgorithmTestHelper, ConcatOutputs);
// TODO(superjom) test backward
#endif
};
class DynamicRecurrentOp : public framework::OperatorBase {
public:
DynamicRecurrentOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
DynamicRecurrentOp(const DynamicRecurrentOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
PADDLE_THROW("Not implemented");
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;
mutable RNNAlgorithm rnn;
};
class DynamicRecurrentGradientOp : public framework::OperatorBase {
public:
DynamicRecurrentGradientOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
DynamicRecurrentGradientOp(const DynamicRecurrentGradientOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
PADDLE_THROW("Not implemented");
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;
mutable RNNAlgorithm rnn;
};
} // namespace operators
} // namespace paddle
#include "paddle/operators/dynamic_recurrent_op.h"
#include <gtest/gtest.h>
#include "paddle/framework/ddim.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
using framework::Scope;
using framework::TensorArray;
using framework::LoDTensor;
using framework::Variable;
class TestOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(TestOp);
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
};
void OpDescNewVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
var->add_arguments(arg_name);
}
}
// create a LoD tensor in scope with specific dims
LoDTensor* CreateVar(Scope& scope, std::string name, framework::DDim dims,
const platform::Place& place) {
auto* var = scope.Var(name);
auto* tensor = var->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(place);
return tensor;
}
class RNNAlgorithmTestHelper : public ::testing::Test {
protected:
const rnn::ArgumentName argname = RNNAlgorithm::kArgNames[0];
virtual void SetUp() override {
CreateGlobalVariables();
auto op_desc = CreateOpDesc();
op = paddle::framework::OpRegistry::CreateOp(op_desc);
dop = &(dynamic_cast<DynamicRecurrentOp*>(op.get())->rnn);
InitCacheManually();
InitStepNet();
}
framework::OpDesc CreateOpDesc() {
// create op
paddle::framework::OpDesc op_desc;
op_desc.set_type("dynamic_recurrent");
OpDescNewVar(argname.inlinks, {"in0"}, op_desc.add_inputs());
OpDescNewVar(argname.initial_states, {"boot_mem"}, op_desc.add_inputs());
OpDescNewVar(argname.step_scopes, {"step_scopes"}, op_desc.add_outputs());
OpDescNewVar(argname.outlinks, {"out0"}, op_desc.add_outputs());
// set pre-states
auto pre_memories = op_desc.mutable_attrs()->Add();
pre_memories->set_name(argname.ex_states);
pre_memories->set_type(paddle::framework::AttrType::STRINGS);
auto pre_memories_item = pre_memories->add_strings();
*pre_memories_item = "mem@pre";
// set states
auto memories = op_desc.mutable_attrs()->Add();
memories->set_name(argname.states);
memories->set_type(paddle::framework::AttrType::STRINGS);
auto memories_item = memories->add_strings();
*memories_item = "mem";
return op_desc;
}
void CreateGlobalVariables() {
platform::CPUPlace place;
scope.Var("step_scopes");
CreateVar(scope, "boot_mem", framework::make_ddim({10, 20}), place);
CreateVar(scope, "out0", framework::make_ddim({10, 20}), place);
auto* in0 = CreateVar(scope, "in0", framework::make_ddim({10, 8}), place);
// 10 instances in 4 sentences, with lengths 4, 3, 2, and 1 respectively.
framework::LoD in0_lod(1);
for (int x : std::vector<int>{0, 4, 7, 9, 10}) {
in0_lod[0].push_back(x);
}
in0->set_lod(in0_lod);
in0->Resize(framework::make_ddim({10, 8}));
// set the content; each element's value is seqid.batchid,
// where seqid starts from 0
int start = 0;
for (size_t seqid = 0; seqid < in0_lod[0].size() - 1; seqid++) {
for (size_t batchid = 0;
batchid < in0_lod[0][seqid + 1] - in0_lod[0][seqid]; batchid++) {
float v = seqid + batchid * 0.1;
for (size_t dim = 0; dim < 8; dim++) {
in0->data<float>()[start * 8 + dim] = v;
}
start++;
}
}
}
void InitCacheManually() {
dop->cache_.Init(RNNAlgorithm::kArgNames[0], *op, scope, &device_context,
&dop->arg_);
}
void InitStepNet() {
std::unique_ptr<framework::OperatorBase> stepnet{new NetOp};
dynamic_cast<NetOp*>(stepnet.get())
->AppendOp(std::unique_ptr<TestOp>(new TestOp(
"test", {{"inputs", {"in0"}}, {"initial_states", {"boot_mem"}}},
{{"outputs", {"out0"}}, {"step_scopes", {"step_scopes"}}}, {})));
dop->SetStepUnit(std::move(stepnet));
}
protected:
RNNAlgorithm* dop;
std::unique_ptr<framework::OperatorBase> op;
paddle::platform::CPUDeviceContext device_context;
paddle::framework::Scope scope;
};
TEST_F(RNNAlgorithmTestHelper, CreateCache) {
const rnn::Argument& arg = dop->arg_;
ASSERT_EQ(arg.inlinks.size(), 1UL);
ASSERT_EQ(arg.outlinks.size(), 1UL);
}
TEST_F(RNNAlgorithmTestHelper, SplitInputs) {
dop->SplitInputs();
auto& in0_ta = dop->step_inputs_["in0"];
ASSERT_EQ(in0_ta.size(), 4UL);
const auto& batch0 = in0_ta.Read(0);
const auto& batch1 = in0_ta.Read(1);
const auto& batch2 = in0_ta.Read(2);
const auto& batch3 = in0_ta.Read(3);
EXPECT_EQ(batch0.dims()[0], 4);
EXPECT_EQ(batch1.dims()[0], 3);
EXPECT_EQ(batch2.dims()[0], 2);
EXPECT_EQ(batch3.dims()[0], 1);
}
TEST_F(RNNAlgorithmTestHelper, CreateScopes) {
dop->SplitInputs();
dop->CreateScopes();
ASSERT_EQ(dop->cache_.num_steps, 4UL);
ASSERT_EQ(dop->cache_.scopes->size(), 4UL);
}
TEST_F(RNNAlgorithmTestHelper, WriteStepInputs) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
for (auto name : std::vector<std::string>({"in0"})) {
ASSERT_TRUE(scope.FindVar(name) != nullptr);
}
}
}
TEST_F(RNNAlgorithmTestHelper, WriteStepOutputs) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
dop->WriteStepOutputs();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
for (auto name : std::vector<std::string>({"out0"})) {
ASSERT_TRUE(scope.FindVar(name));
}
}
}
TEST_F(RNNAlgorithmTestHelper, ConcatOutputs) {
// Let's leave this test to python unittest.
}
TEST_F(RNNAlgorithmTestHelper, InitStates) {
dop->SetComputeMode(RNNAlgorithm::ComputeMode::kForward);
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
dop->WriteStepOutputs();
dop->InitStates();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
auto state = scope.FindVar("mem");
ASSERT_TRUE(state != nullptr);
auto* pre_state = scope.FindVar("mem@pre");
ASSERT_TRUE(pre_state != nullptr);
auto* boot_state = scope.FindVar("boot_mem");
ASSERT_TRUE(boot_state != nullptr);
}
}
} // namespace operators
} // namespace paddle
......@@ -125,7 +125,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
out0->CopyFrom(*in0, context.GetPlace(), context.device_context());
framework::CopyFrom(*in0, context.GetPlace(), context.device_context(),
out0);
} else {
switch (dims) {
REP_EXPAND_GRAD_TEMPLATE(72)
......
......@@ -47,7 +47,7 @@ class FeedOp : public framework::OperatorBase {
auto &feed_list = feed_var->Get<framework::FeedFetchList>();
auto &feed_item = feed_list.at(static_cast<size_t>(col));
auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
out_item->CopyFrom(feed_item, dev_ctx.GetPlace(), dev_ctx);
framework::CopyFrom(feed_item, dev_ctx.GetPlace(), dev_ctx, out_item);
out_item->set_lod(feed_item.lod());
}
};
......
......@@ -51,7 +51,7 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
dst_item.CopyFrom(src_item, platform::CPUPlace(), dev_ctx);
CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
dev_ctx.Wait();
dst_item.set_lod(src_item.lod());
......
......@@ -52,7 +52,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
......@@ -63,7 +63,7 @@ class FillConstantBatchSizeLikeOpMaker
FillConstantBatchSizeLikeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
......
......@@ -34,7 +34,7 @@ class FillConstantOp : public framework::OperatorBase {
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto data_type = static_cast<framework::DataType>(Attr<int>("data_type"));
auto data_type = static_cast<framework::DataType>(Attr<int>("dtype"));
auto value = Attr<float>("value");
auto force_cpu = Attr<bool>("force_cpu");
auto &out =
......@@ -55,7 +55,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
FillConstantOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
......
......@@ -60,7 +60,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
......@@ -88,7 +88,7 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
"Random seed of generator."
"0 means use system wide seed.")
.SetDefault(0);
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"Output data type.")
.SetDefault(framework::DataType::FP32);
......
......@@ -28,6 +28,10 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 };
template <typename Place, typename T>
......@@ -226,7 +230,7 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
// backward for bias
if (bias_grad) {
bias_grad->mutable_data<T>(context.GetPlace());
auto d_b = EigenMatrix<T>::From(*bias_grad);
auto d_b = EigenVector<T>::Flatten(*bias_grad);
d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
}
}
......
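The fix above views the bias gradient through EigenVector<T>::Flatten because reducing d_g over the batch dimension yields a rank-1 result, which cannot be assigned to a matrix view of a 1-D bias. A minimal sketch with Eigen's unsupported Tensor module (assumed available) showing the rank of the reduction:

#include <unsupported/Eigen/CXX11/Tensor>

void ColumnSumExample() {
  Eigen::Tensor<float, 2> d_g(4, 3);                   // {batch_size, 3 * frame_size}
  d_g.setConstant(1.0f);
  Eigen::array<int, 1> along_batch{{0}};               // reduce over the batch dimension
  Eigen::Tensor<float, 1> d_b = d_g.sum(along_batch);  // rank-1 tensor of length 3
  // d_b(i) == 4.0f for every i
}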
......@@ -70,11 +70,18 @@ input value and Y as the target value. Huber loss can evaluate the fitness of
X to Y. Different from MSE loss, Huber loss is more robust to outliers. The
shapes of X and Y are [batch_size, 1]. The equation is:
L_{\delta}(y, f(x)) =
$$
Out_{\delta}(X, Y)_i =
\begin{cases}
0.5 * (y - f(x))^2, \quad |y - f(x)| \leq \delta \\
\delta * (|y - f(x)| - 0.5 * \delta), \quad otherwise
0.5 * (Y_i - X_i)^2,
\quad |Y_i - X_i| \leq \delta \\
\delta * (|Y_i - X_i| - 0.5 * \delta),
\quad otherwise
\end{cases}
$$
In the above equation, $Out_\delta(X, Y)_i$, $X_i$ and $Y_i$ represent the
$i$-th elements of Out, X and Y, respectively.
)DOC");
}
......
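A minimal element-wise sketch of the Huber loss formula above (hypothetical standalone helper; the real kernel is vectorized over the batch):

#include <cmath>

float HuberLoss(float y, float x, float delta) {
  const float residual = std::fabs(y - x);
  if (residual <= delta) {
    return 0.5f * residual * residual;         // quadratic region
  }
  return delta * (residual - 0.5f * delta);    // linear region for outliers
}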
......@@ -32,19 +32,19 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
"[(D + 2) x D]. The learnable parameter for the linear_chain_crf "
"operator. See more details in the operator's comments.");
AddInput("Label",
"(LoDTensor, default LoDTensor<int>) A LoDTensor with shape "
"(LoDTensor, default LoDTensor<int64_t>) A LoDTensor with shape "
"[N x 1], where N is the total element number in a mini-batch. "
"The ground truth.");
AddOutput(
"Alpha",
"(Tensor, default Tensor<float>) A 2-D Tensor with shape [N x D]. "
"The forward vectors for the entire batch. Denote it as \f$\alpha\f$. "
"\f$\alpha$\f is a memo table used to calculate the normalization "
"factor in CRF. \f$\alpha[k, v]$\f stores the unnormalized "
"The forward vectors for the entire batch. Denote it as $\alpha$. "
"$\alpha$ is a memo table used to calculate the normalization "
"factor in CRF. $\alpha[k, v]$ stores the unnormalized "
"probabilites of all possible unfinished sequences of tags that end at "
"position \f$k$\f with tag \f$v$\f. For each \f$k$\f, "
"\f$\alpha[k, v]$\f is a vector of length \f$D$\f with a component for "
"each tag value \f$v$\f. This vector is called a forward vecotr and "
"position $k$ with tag $v$. For each $k$, "
"$\alpha[k, v]$ is a vector of length $D$ with a component for "
"each tag value $v$. This vector is called a forward vecotr and "
"will also be used in backward computations.")
.AsIntermediate();
AddOutput(
......@@ -73,9 +73,9 @@ LinearChainCRF Operator.
Conditional Random Field defines an undirected probabilistic graph with nodes
denoting random variables and edges denoting dependencies between these
variables. CRF learns the conditional probability \f$P(Y|X)\f$, where
\f$X = (x_1, x_2, ... , x_n)\f$ are structured inputs and
\f$Y = (y_1, y_2, ... , y_n)\f$ are labels for the inputs.
variables. CRF learns the conditional probability $P(Y|X)$, where
$X = (x_1, x_2, ... , x_n)$ are structured inputs and
$Y = (y_1, y_2, ... , y_n)$ are labels for the inputs.
Linear chain CRF is a special case of CRF that is useful for sequence labeling
tasks. Sequence labeling tasks do not assume a lot of conditional
......@@ -88,21 +88,22 @@ CRF. Please refer to http://www.cs.columbia.edu/~mcollins/fb.pdf and
http://cseweb.ucsd.edu/~elkan/250Bwinter2012/loglinearCRFs.pdf for details.
Equation:
1. Denote Input(Emission) to this operator as \f$x\f$ here.
1. Denote Input(Emission) to this operator as $x$ here.
2. The first D values of Input(Transition) to this operator are for starting
weights, denoted as \f$a\f$ here.
weights, denoted as $a$ here.
3. The next D values of Input(Transition) of this operator are for ending
weights, denoted as \f$b\f$ here.
weights, denoted as $b$ here.
4. The remaining values of Input(Transition) are for transition weights,
denoted as \f$w\f$ here.
5. Denote Input(Label) as \f$s\f$ here.
The probability of a sequence \f$s\f$ of length \f$L\f$ is defined as:
\f$P(s) = (1/Z) \exp(a_{s_1} + b_{s_L}
+ \sum_{l=1}^L x_{s_l}
+ \sum_{l=2}^L w_{s_{l-1},s_l})\f$
where \f$Z\f$ is a normalization value so that the sum of \f$P(s)\f$ over
all possible sequences is \f$1\f$, and \f$x\f$ is the emission feature weight
denoted as $w$ here.
5. Denote Input(Label) as $s$ here.
The probability of a sequence $s$ of length $L$ is defined as:
$$P(s) = (1/Z) \exp(a_{s_1} + b_{s_L}
+ \sum_{l=1}^L x_{s_l}
+ \sum_{l=2}^L w_{s_{l-1},s_l})$$
where $Z$ is a normalization value so that the sum of $P(s)$ over
all possible sequences is 1, and $x$ is the emission feature weight
to the linear chain CRF.
Finally, the linear chain CRF operator outputs the logarithm of the conditional
......
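For reference, the normalization factor that the Alpha memo table accumulates follows the standard forward recursion. This is a sketch under the notation above, with $x_{k, v}$ denoting the emission score of tag $v$ at position $k$ (it is not text from this patch):

$$
\alpha[1, v] = \exp(a_v + x_{1, v}), \qquad
\alpha[k, v] = \exp(x_{k, v}) \sum_{u} \alpha[k-1, u]\, \exp(w_{u, v}), \qquad
Z = \sum_{v} \alpha[L, v]\, \exp(b_v).
$$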
......@@ -195,7 +195,7 @@ class LinearChainCRFOpKernel : public framework::OpKernel<T> {
auto copyLoDTensor = [](const platform::DeviceContext& ctx,
const LoDTensor& src, LoDTensor* dst) {
dst->mutable_data<T>(src.dims(), platform::CPUPlace());
dst->CopyFrom(src, platform::CPUPlace(), ctx);
framework::CopyFrom(src, platform::CPUPlace(), ctx, dst);
};
copyLoDTensor(ctx, emission_weights_src, emission_weights_dst);
......@@ -203,8 +203,8 @@ class LinearChainCRFOpKernel : public framework::OpKernel<T> {
transition_weights_dst->mutable_data<T>(transition_weights_src.dims(),
platform::CPUPlace());
transition_weights_dst->CopyFrom(transition_weights_src,
platform::CPUPlace(), ctx);
framework::CopyFrom(transition_weights_src, platform::CPUPlace(), ctx,
transition_weights_dst);
}
void CopyOutputsToGpuMemory(const platform::DeviceContext& ctx,
......@@ -219,7 +219,7 @@ class LinearChainCRFOpKernel : public framework::OpKernel<T> {
auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor& src,
Tensor* dst) {
dst->mutable_data<T>(platform::GPUPlace());
dst->CopyFrom(src, platform::GPUPlace(), ctx);
framework::CopyFrom(src, platform::GPUPlace(), ctx, dst);
};
copyTensor(ctx, emission_exps_src, emission_exps_dst);
copyTensor(ctx, transition_exps_src, transition_exps_dst);
......@@ -410,12 +410,12 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
// Copy the inputs from GPU memory to CPU memory when this operators runs on
// GPU device.
label_dst->mutable_data<T>(label_src.dims(), platform::CPUPlace());
label_dst->CopyFrom(label_src, platform::CPUPlace(), ctx);
framework::CopyFrom(label_src, platform::CPUPlace(), ctx, label_dst);
auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor& src,
Tensor* dst) {
dst->mutable_data<T>(src.dims(), platform::CPUPlace());
dst->CopyFrom(src, platform::CPUPlace(), ctx);
framework::CopyFrom(src, platform::CPUPlace(), ctx, dst);
};
copyTensor(ctx, emission_exps_src, emission_exps_dst);
copyTensor(ctx, transition_exps_src, transition_exps_dst);
......@@ -434,7 +434,7 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
Tensor* dst) {
if (src && dst) {
dst->mutable_data<T>(platform::GPUPlace());
dst->CopyFrom(*src, platform::GPUPlace(), ctx);
framework::CopyFrom(*src, platform::GPUPlace(), ctx, dst);
}
};
copyTensor(ctx, emission_grad_src, emission_grad_dst);
......
......@@ -105,7 +105,7 @@ class LoadOp : public framework::OperatorBase {
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod());
tensor->CopyFrom(cpu_tensor, place, dev_ctx);
CopyFrom(cpu_tensor, place, dev_ctx, tensor);
}
}
};
......
......@@ -33,7 +33,8 @@ class LoDResetKernel : public framework::OpKernel<T> {
auto* lod = lod_t->data<int>();
if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor lod_cpu;
lod_cpu.CopyFrom(*lod_t, platform::CPUPlace(), ctx.device_context());
framework::CopyFrom(*lod_t, platform::CPUPlace(), ctx.device_context(),
&lod_cpu);
lod = lod_cpu.data<int>();
}
level0 = std::vector<int>(lod, lod + lod_t->numel());
......
......@@ -81,11 +81,11 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
continue;
}
// out[i][offset: offset+len] = x[each_range.begin: each_range.end]
out[i]
.Slice(static_cast<int>(offset), static_cast<int>(offset + len))
.CopyFrom(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)),
x.place(), dev_ctx);
auto slice = out[i].Slice(static_cast<int>(offset),
static_cast<int>(offset + len));
framework::CopyFrom(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)),
x.place(), dev_ctx, &slice);
offset += len;
}
}
......
......@@ -149,7 +149,7 @@ class ContextProjectFunctor {
Tensor out_t_sub = out_t.Slice(k * context_length,
k * context_length + padding_size);
Tensor w_sub = padding_data.Slice(k, k + padding_size);
out_t_sub.CopyFrom(w_sub, context.GetPlace(), context);
framework::CopyFrom(w_sub, context.GetPlace(), context, &out_t_sub);
}
}
if (down_pad > 0) { // add down pad
......@@ -179,7 +179,7 @@ class ContextProjectFunctor {
(down_pad_begin_row + t) * context_length);
Tensor w_sub = padding_data.Slice(
up_pad + padding_idx, up_pad + padding_idx + padding_size);
out_t_sub.CopyFrom(w_sub, context.GetPlace(), context);
framework::CopyFrom(w_sub, context.GetPlace(), context, &out_t_sub);
}
}
out_t.Resize({sequence_height, context_length * sequence_width});
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"
namespace paddle {
......
......@@ -74,7 +74,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom(input_tmp, *place, *context);
CopyFrom(input_tmp, *place, *context, &input);
}
output_cfo.mutable_data<float>(
{1, filter_size, filter_size, output_height, output_width}, *place);
......@@ -99,7 +99,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) {
out_cfo_ptr = output_cfo.data<float>();
} else {
output_tmp.CopyFrom(output_cfo, paddle::platform::CPUPlace(), *context);
CopyFrom(output_cfo, paddle::platform::CPUPlace(), *context, &output_tmp);
out_cfo_ptr = output_tmp.data<float>();
}
for (int i = 0; i < 6; ++i) {
......@@ -110,7 +110,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) {
out_ocf_ptr = output_ocf.data<float>();
} else {
output_tmp.CopyFrom(output_ocf, paddle::platform::CPUPlace(), *context);
CopyFrom(output_ocf, paddle::platform::CPUPlace(), *context, &output_tmp);
out_ocf_ptr = output_tmp.data<float>();
}
for (int i = 0; i < 6; ++i) {
......@@ -130,7 +130,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom(input_tmp, *place, *context);
CopyFrom(input_tmp, *place, *context, &input);
}
col2im(*context, output_cfo, dilation, stride, padding, &input);
......@@ -139,7 +139,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
} else {
input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context);
CopyFrom(input, paddle::platform::CPUPlace(), *context, &input_tmp);
in_ptr = input_tmp.data<float>();
}
for (int i = 0; i < 6; ++i) {
......@@ -151,7 +151,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom(input_tmp, *place, *context);
CopyFrom(input_tmp, *place, *context, &input);
}
col2im_ocf(*context, output_ocf, dilation, stride, padding, &input);
......@@ -159,7 +159,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
} else {
input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context);
CopyFrom(input, paddle::platform::CPUPlace(), *context, &input_tmp);
in_ptr = input_tmp.data<float>();
}
for (int i = 0; i < 6; ++i) {
......
......@@ -297,7 +297,25 @@ void set_constant_with_place<platform::GPUPlace>(
template struct RowwiseAdd<platform::GPUPlace, float>;
template struct RowwiseAdd<platform::GPUPlace, double>;
template struct ColwiseSum<platform::GPUPlace, float>;
template struct ColwiseSum<platform::GPUPlace, double>;
// template struct ColwiseSum<platform::GPUPlace, double>;
// The Eigen-based ColwiseSum<platform::GPUPlace, double> fails in debug mode,
// and only for this case, so it is reimplemented here with a gemv call.
template <>
void ColwiseSum<platform::GPUPlace, double>::operator()(
const platform::DeviceContext& context, const framework::Tensor& input,
framework::Tensor* vector) {
auto in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector->numel(), size);
framework::Tensor one;
one.mutable_data<double>({in_dims[0]}, context.GetPlace());
SetConstant<platform::GPUPlace, double> set;
set(context, &one, static_cast<double>(1.0));
gemv<platform::GPUPlace, double>(context, true, static_cast<int>(in_dims[0]),
static_cast<int>(in_dims[1]), 1.0,
input.data<double>(), one.data<double>(),
0.0, vector->data<double>());
}
} // namespace math
} // namespace operators
......
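The workaround above computes the column-wise sum as a matrix-vector product: multiplying the transposed input by a vector of ones gives, for every column, the sum over its rows, which is exactly what the gemv call with trans = true does. A minimal CPU sketch of that identity (plain C++ with a hypothetical helper name, no Paddle or cuBLAS dependency):

#include <cassert>
#include <vector>

// colwise_sum_via_gemv: y = A^T * ones, where A is rows x cols (row-major).
// This is the same identity the specialized ColwiseSum<GPUPlace, double>
// above relies on, just written with plain loops instead of cuBLAS gemv.
std::vector<double> colwise_sum_via_gemv(const std::vector<double>& a,
                                         int rows, int cols) {
  std::vector<double> ones(rows, 1.0);
  std::vector<double> y(cols, 0.0);
  for (int j = 0; j < cols; ++j) {
    for (int i = 0; i < rows; ++i) {
      y[j] += a[i * cols + j] * ones[i];  // sum over rows of column j
    }
  }
  return y;
}

int main() {
  // 2 x 3 input: the columns sum to {5, 7, 9}.
  std::vector<double> a = {1, 2, 3, 4, 5, 6};
  auto y = colwise_sum_via_gemv(a, 2, 3);
  assert(y[0] == 5 && y[1] == 7 && y[2] == 9);
  return 0;
}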
......@@ -49,6 +49,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
......
......@@ -16,15 +16,15 @@ TEST(math_function, notrans_mul_trans) {
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom(input1, *gpu_place, context);
input2_gpu.CopyFrom(input1, *gpu_place, context);
paddle::framework::CopyFrom(input1, *gpu_place, context, &input1_gpu);
paddle::framework::CopyFrom(input1, *gpu_place, context, &input2_gpu);
out_gpu.mutable_data<float>({2, 2}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0);
out.CopyFrom(out_gpu, *cpu_place, context);
paddle::framework::CopyFrom(out_gpu, *cpu_place, context, &out);
float* out_ptr = out.data<float>();
context.Wait();
......@@ -50,15 +50,15 @@ TEST(math_function, trans_mul_notrans) {
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom(input1, *gpu_place, context);
input2_gpu.CopyFrom(input1, *gpu_place, context);
paddle::framework::CopyFrom(input1, *gpu_place, context, &input1_gpu);
paddle::framework::CopyFrom(input1, *gpu_place, context, &input2_gpu);
out_gpu.mutable_data<float>({3, 3}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0);
out.CopyFrom(out_gpu, *cpu_place, context);
paddle::framework::CopyFrom(out_gpu, *cpu_place, context, &out);
float* out_ptr = out.data<float>();
context.Wait();
......@@ -99,9 +99,9 @@ TEST(math_function, gemm_notrans_cublas) {
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom(input1, *gpu_place, context);
input2_gpu.CopyFrom(input2, *gpu_place, context);
input3_gpu.CopyFrom(input3, *gpu_place, context);
paddle::framework::CopyFrom(input1, *gpu_place, context, &input1_gpu);
paddle::framework::CopyFrom(input2, *gpu_place, context, &input2_gpu);
paddle::framework::CopyFrom(input3, *gpu_place, context, &input3_gpu);
float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(*gpu_place);
......@@ -109,7 +109,7 @@ TEST(math_function, gemm_notrans_cublas) {
paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4);
input3.CopyFrom(input3_gpu, *cpu_place, context);
paddle::framework::CopyFrom(input3_gpu, *cpu_place, context, &input3);
// numpy code:
// a = np.arange(6).reshape(2, 3)
......@@ -154,9 +154,9 @@ TEST(math_function, gemm_trans_cublas) {
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom(input1, *gpu_place, context);
input2_gpu.CopyFrom(input2, *gpu_place, context);
input3_gpu.CopyFrom(input3, *gpu_place, context);
paddle::framework::CopyFrom(input1, *gpu_place, context, &input1_gpu);
paddle::framework::CopyFrom(input2, *gpu_place, context, &input2_gpu);
paddle::framework::CopyFrom(input3, *gpu_place, context, &input3_gpu);
float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(*gpu_place);
......@@ -164,7 +164,7 @@ TEST(math_function, gemm_trans_cublas) {
paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4);
input3.CopyFrom(input3_gpu, *cpu_place, context);
paddle::framework::CopyFrom(input3_gpu, *cpu_place, context, &input3);
context.Wait();
EXPECT_EQ(input3_ptr[0], 0);
......@@ -205,14 +205,15 @@ void GemvTest(int m, int n, bool trans) {
}
paddle::platform::CUDADeviceContext context(*gpu_place);
g_mat_a.CopyFrom(mat_a, *gpu_place, context);
g_vec_b.CopyFrom(vec_b, *gpu_place, context);
paddle::framework::CopyFrom(mat_a, *gpu_place, context, &g_mat_a);
paddle::framework::CopyFrom(vec_b, *gpu_place, context, &g_vec_b);
paddle::operators::math::gemv<paddle::platform::GPUPlace, T>(
context, trans, static_cast<int>(m), static_cast<int>(n), 1., g_data_a,
g_data_b, 0., g_data_c);
vec_c.CopyFrom(g_vec_c, paddle::platform::CPUPlace(), context);
paddle::framework::CopyFrom(g_vec_c, paddle::platform::CPUPlace(), context,
&vec_c);
if (!trans) {
for (int i = 0; i < m; ++i) {
......
......@@ -145,6 +145,8 @@ struct SelectedRowsAddTo<platform::CPUPlace, T> {
template struct SelectedRowsAddTo<platform::CPUPlace, float>;
template struct SelectedRowsAddTo<platform::CPUPlace, double>;
template struct SelectedRowsAddTo<platform::CPUPlace, int>;
template struct SelectedRowsAddTo<platform::CPUPlace, int64_t>;
template <typename T>
struct SelectedRowsAddToTensor<platform::CPUPlace, T> {
......@@ -175,6 +177,8 @@ struct SelectedRowsAddToTensor<platform::CPUPlace, T> {
template struct SelectedRowsAddToTensor<platform::CPUPlace, float>;
template struct SelectedRowsAddToTensor<platform::CPUPlace, double>;
template struct SelectedRowsAddToTensor<platform::CPUPlace, int>;
template struct SelectedRowsAddToTensor<platform::CPUPlace, int64_t>;
} // namespace math
} // namespace operators
......
......@@ -173,6 +173,8 @@ struct SelectedRowsAddTo<platform::GPUPlace, T> {
template struct SelectedRowsAddTo<platform::GPUPlace, float>;
template struct SelectedRowsAddTo<platform::GPUPlace, double>;
template struct SelectedRowsAddTo<platform::GPUPlace, int>;
template struct SelectedRowsAddTo<platform::GPUPlace, int64_t>;
namespace {
template <typename T, int block_size>
......@@ -223,6 +225,8 @@ struct SelectedRowsAddToTensor<platform::GPUPlace, T> {
template struct SelectedRowsAddToTensor<platform::GPUPlace, float>;
template struct SelectedRowsAddToTensor<platform::GPUPlace, double>;
template struct SelectedRowsAddToTensor<platform::GPUPlace, int>;
template struct SelectedRowsAddToTensor<platform::GPUPlace, int64_t>;
} // namespace math
} // namespace operators
......
......@@ -67,7 +67,7 @@ TEST(selected_rows_functor, gpu_add) {
EXPECT_EQ(out_rows[6], 9);
Tensor out_cpu;
out_cpu.CopyFrom(*out_value, cpu_place, ctx);
CopyFrom(*out_value, cpu_place, ctx, &out_cpu);
ctx.Wait();
auto* out_cpu_data = out_cpu.data<float>();
......@@ -94,7 +94,7 @@ TEST(selected_rows_functor, gpu_add) {
add_tensor_functor(ctx, *output, *tensor1, tensor2.get());
Tensor tensor2_cpu;
tensor2_cpu.CopyFrom(*tensor2, cpu_place, ctx);
CopyFrom(*tensor2, cpu_place, ctx, &tensor2_cpu);
ctx.Wait();
auto* tensor2_cpu_data = tensor2_cpu.data<float>();
......@@ -167,7 +167,7 @@ TEST(selected_rows_functor, gpu_add_to) {
EXPECT_EQ(out_rows[6], 9);
Tensor out_cpu;
out_cpu.CopyFrom(*out_value, cpu_place, ctx);
CopyFrom(*out_value, cpu_place, ctx, &out_cpu);
ctx.Wait();
auto* out_cpu_data = out_cpu.data<float>();
......@@ -191,7 +191,7 @@ TEST(selected_rows_functor, gpu_add_to) {
add_to_tensor_functor(ctx, *output, tensor1.get());
Tensor tensor1_cpu;
tensor1_cpu.CopyFrom(*tensor1, cpu_place, ctx);
CopyFrom(*tensor1, cpu_place, ctx, &tensor1_cpu);
ctx.Wait();
auto* tensor1_cpu_data = tensor1_cpu.data<float>();
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"
namespace paddle {
......
......@@ -82,7 +82,7 @@ void testVol2col() {
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom(input_tmp, *place, *context);
CopyFrom(input_tmp, *place, *context, &input);
}
output.mutable_data<float>({1, filter_size, filter_size, filter_size,
output_depth, output_height, output_width},
......@@ -96,7 +96,7 @@ void testVol2col() {
if (paddle::platform::is_cpu_place(*place)) {
out_cfo_ptr = output.data<float>();
} else {
output_tmp.CopyFrom(output, paddle::platform::CPUPlace(), *context);
CopyFrom(output, paddle::platform::CPUPlace(), *context, &output_tmp);
out_cfo_ptr = output_tmp.data<float>();
}
......@@ -110,7 +110,7 @@ void testVol2col() {
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom(input_tmp, *place, *context);
CopyFrom(input_tmp, *place, *context, &input);
}
paddle::operators::math::Col2VolFunctor<Place, float> col2vol;
......@@ -120,7 +120,7 @@ void testVol2col() {
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
} else {
input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context);
CopyFrom(input, paddle::platform::CPUPlace(), *context, &input_tmp);
in_ptr = input_tmp.data<float>();
}
......
......@@ -45,7 +45,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
cpu_mask->ShareDataWith(mask);
} else if (platform::is_gpu_place(mask.place())) {
#ifdef PADDLE_WITH_CUDA
cpu_mask->CopyFrom(mask, platform::CPUPlace(), dev_ctx);
framework::CopyFrom(mask, platform::CPUPlace(), dev_ctx, cpu_mask.get());
#else
PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option");
#endif
......@@ -99,8 +99,9 @@ class MergeLoDTensorOp : public framework::OperatorBase {
if (len == 0) {
continue;
}
out->Slice(out_offset, out_offset + len)
.CopyFrom(input->Slice(start_offset, end_offset), place, dev_ctx);
auto slice = out->Slice(out_offset, out_offset + len);
framework::CopyFrom(input->Slice(start_offset, end_offset), place,
dev_ctx, &slice);
out_offset += len;
(*in_idx) += 1;
}
......
......@@ -33,7 +33,7 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
auto cols = ins[0]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
index_t_cpu.CopyFrom(*ids, platform::CPUPlace(), ctx.device_context());
CopyFrom(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream();
Place place = boost::get<Place>(ctx.GetPlace());
......@@ -68,7 +68,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
auto cols = ins[0]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
index_t_cpu.CopyFrom(*ids, platform::CPUPlace(), ctx.device_context());
CopyFrom(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream();
......
......@@ -49,7 +49,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Communicator",
"Create Communicator for communicating between gpus");
AddAttr<std::vector<int>>("gpus", "(vector<int>) GPU id lists");
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
......
......@@ -97,7 +97,7 @@ class NCCLTester : public ::testing::Test {
send_tensor->mutable_data<T>(kDims, place);
std::vector<T> send_vector(f::product(kDims), gpu_id);
send_tensor->CopyFromVector<T>(send_vector, *ctx);
paddle::framework::CopyFromVector<T>(send_vector, *ctx, send_tensor);
ctx->Wait();
VLOG(1) << "Send Tensor filled with elements " << send_tensor->numel();
}
......
......@@ -284,7 +284,8 @@ class RecurrentOp : public RecurrentBase {
auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
// Explicit copy output since the local RNN scope can be destroyed
// early.
dst_out.CopyFrom(src_tensor, dev_ctx.GetPlace(), dev_ctx);
framework::CopyFrom(src_tensor, dev_ctx.GetPlace(), dev_ctx,
&dst_out);
});
scopes.Next();
......@@ -365,7 +366,8 @@ class RecurrentGradOp : public RecurrentBase {
auto *cur_grad_var = cur_scope.Var(cur_grad);
auto cur_grad_tensor =
cur_grad_var->GetMutable<framework::LoDTensor>();
cur_grad_tensor->CopyFrom(ex_tensor, dev_ctx.GetPlace(), dev_ctx);
framework::CopyFrom(ex_tensor, dev_ctx.GetPlace(), dev_ctx,
cur_grad_tensor);
}
}
......@@ -401,7 +403,7 @@ class RecurrentGradOp : public RecurrentBase {
auto &inside_tensor = cur_scope.FindVar(inside_grad_name)
->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["data_type"] = framework::ToDataType(inside_tensor.type());
attrs["dtype"] = framework::ToDataType(inside_tensor.type());
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f;
......@@ -438,7 +440,7 @@ class RecurrentGradOp : public RecurrentBase {
}
auto dst = outside->Slice(seq_offset, seq_offset + 1);
dst.CopyFrom(inside, dev_ctx.GetPlace(), dev_ctx);
framework::CopyFrom(inside, dev_ctx.GetPlace(), dev_ctx, &dst);
});
VLOG(5) << "Link outside gradient finished ";
......@@ -451,7 +453,7 @@ class RecurrentGradOp : public RecurrentBase {
framework::LoDTensor *outside) {
outside->Resize(inside.dims());
outside->mutable_data(dev_ctx.GetPlace(), inside.type());
outside->CopyFrom(inside, dev_ctx.GetPlace(), dev_ctx);
framework::CopyFrom(inside, dev_ctx.GetPlace(), dev_ctx, outside);
});
VLOG(5) << "Link initialize state gradient finished ";
}
......
......@@ -28,7 +28,7 @@ class ReshapeKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>("X");
auto out_dims = out->dims();
out->mutable_data<T>(ctx.GetPlace());
out->CopyFrom(*in, ctx.GetPlace(), ctx.device_context());
framework::CopyFrom(*in, ctx.GetPlace(), ctx.device_context(), out);
out->Resize(out_dims);
}
};
......@@ -42,7 +42,7 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
d_x->mutable_data<T>(ctx.GetPlace());
auto in_dims = d_x->dims();
d_x->CopyFrom(*d_out, ctx.GetPlace(), ctx.device_context());
framework::CopyFrom(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
d_x->Resize(in_dims);
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/rnn/recurrent_op_utils.h"
namespace paddle {
namespace operators {
namespace rnn {
namespace f = paddle::framework;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<std::string>& inlinks,
const size_t seq_len) {
PADDLE_ENFORCE(!inlinks.empty(), "no input links are provided.");
for (size_t i = 0; i < inlinks.size(); ++i) {
// global inputs
auto input_var = step_scopes[0]->parent().FindVar(inlinks[i]);
PADDLE_ENFORCE_NOT_NULL(input_var, "input link [%s] is not in scope.",
inlinks[i]);
LoDTensor* input = input_var->GetMutable<LoDTensor>();
f::DDim dims = input->dims();
PADDLE_ENFORCE_EQ(static_cast<size_t>(dims[0]), seq_len,
"all the inputs be the same length");
f::DDim step_dims = slice_ddim(dims, 1, dims.size());
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_input =
step_scopes[j]->Var(inlinks[i])->GetMutable<Tensor>();
// The input of operators of each step is Tensor here.
// Maybe need to modify Slice function.
*step_input = input->Slice(j, j + 1);
step_input->Resize(step_dims);
}
}
}
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<std::string>& outlinks,
const size_t seq_len, const platform::DeviceContext& ctx) {
for (size_t i = 0; i < outlinks.size(); i++) {
auto* output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
outlinks[i]);
LoDTensor* output = output_var->GetMutable<LoDTensor>();
auto* step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
f::DDim step_dims =
step_scope_var->template GetMutable<LoDTensor>()->dims();
std::vector<int64_t> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len);
output->Resize(f::make_ddim(dims_vec));
output->mutable_data<float>(platform::CPUPlace());
for (size_t j = 0; j < seq_len; j++) {
LoDTensor* step_output =
step_scopes[j]->FindVar(outlinks[i])->GetMutable<LoDTensor>();
// TODO(luotao02) data type and platform::DeviceContext() should be set
// correctly
(output->Slice(j, j + 1))
.CopyFrom(*step_output, platform::CPUPlace(), ctx);
}
}
}
void LinkMemories(const std::vector<Scope*>& scopes,
const std::vector<rnn::StateAttr>& memories,
const size_t step_id, const int offset) {
PADDLE_ENFORCE_LT(step_id, scopes.size(),
"step [%d] is out of range of step scopes' size [%d]",
step_id, scopes.size());
PADDLE_ENFORCE_GE(static_cast<int>(step_id) + offset, 0,
"offset [%d] must be large than -[%d]", offset, step_id);
PADDLE_ENFORCE_LT(
step_id + offset, scopes.size(),
"offset [%d] is out of range, it must be less than (%d - %d)", offset,
scopes.size(), step_id);
auto* scope = scopes[step_id];
auto* linked_scope = scopes[step_id + offset];
for (auto& attr : memories) {
auto* mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
auto* linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
mem->Resize(linked_mem->dims());
mem->ShareDataWith(*linked_mem);
}
}
void InitArgument(const ArgumentName& name, Argument* arg,
const framework::OperatorBase& op, bool is_grad) {
arg->step_scopes =
is_grad ? op.Input(name.step_scopes) : op.Output(name.step_scopes);
arg->inlinks = op.Inputs(name.inlinks);
arg->outlinks = op.Outputs(name.outlinks);
auto& boot_memories = is_grad ? op.Outputs(name.initial_states)
: op.Inputs(name.initial_states);
// attributes
auto& memories = op.Attr<std::vector<std::string>>(name.states);
auto& pre_memories = op.Attr<std::vector<std::string>>(name.ex_states);
PADDLE_ENFORCE(memories.size() == boot_memories.size(),
"the size of states, initial_states don't match:%d,%d",
memories.size(), boot_memories.size());
PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(),
"the size of ex_states, initial_states don't match:%d,%d",
pre_memories.size(), boot_memories.size());
PADDLE_ENFORCE(memories.size() > 0, "at least one state should be set");
for (size_t i = 0; i < memories.size(); ++i) {
rnn::StateAttr mem_attr;
mem_attr.var = memories[i];
mem_attr.pre_var = pre_memories[i];
mem_attr.boot_var = boot_memories[i];
(arg->states).push_back(mem_attr);
}
}
} // namespace rnn
} // namespace operators
} // namespace paddle
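SegmentInputs and ConcatOutputs above are essentially a split and a concatenation along the sequence (first) dimension of the in/out link tensors. A minimal sketch of that index bookkeeping with plain contiguous buffers (hypothetical helpers; the real LoDTensor Slice shares memory instead of copying):

#include <cassert>
#include <vector>

// Split a [seq_len, step_size] buffer into seq_len step buffers, then
// concatenate them back -- the same index arithmetic SegmentInputs /
// ConcatOutputs perform on LoDTensor slices.
std::vector<std::vector<float>> segment(const std::vector<float>& input,
                                        int seq_len, int step_size) {
  std::vector<std::vector<float>> steps(seq_len);
  for (int j = 0; j < seq_len; ++j) {
    steps[j].assign(input.begin() + j * step_size,
                    input.begin() + (j + 1) * step_size);
  }
  return steps;
}

std::vector<float> concat(const std::vector<std::vector<float>>& steps) {
  std::vector<float> out;
  for (const auto& s : steps) out.insert(out.end(), s.begin(), s.end());
  return out;
}

int main() {
  std::vector<float> input = {0, 1, 2, 3, 4, 5};  // seq_len = 3, step_size = 2
  auto steps = segment(input, 3, 2);
  assert(steps[1][0] == 2 && steps[2][1] == 5);  // step j holds row j of the input
  assert(concat(steps) == input);                // concatenation restores the sequence
  return 0;
}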
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
namespace rnn {
using Scope = framework::Scope;
/**
* Memory of an RNN (same as the role of `Memory` in PaddlePaddle).
*
* Memory attributes cached by this op; dims will be inferred from
* boot memories in the parent scope. Other attributes are copied from the
* Op's proto attributes.
*/
struct StateAttr {
// name of current state variable
std::string var;
// name of previous step's state variable
std::string pre_var;
// name of the variable used to init this memory (same role as `boot_layer` in
// PaddlePaddle), which is stored in the parent scope.
std::string boot_var;
};
struct Argument {
std::string step_net;
std::string step_scopes;
std::vector<std::string> inlinks;
std::vector<std::string> outlinks;
std::vector<rnn::StateAttr> states;
};
struct ArgumentName {
std::string step_net;
std::string step_scopes;
std::string inlinks;
std::string outlinks;
std::string states; // the memory name
std::string ex_states; // the previous memory name
std::string initial_states; // the boot memory name
};
/**
* Prepare inputs for each step net.
*/
void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<std::string>& inlinks,
const size_t seq_len);
/**
* Process outputs of step nets and merge to variables.
*/
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<std::string>& outlinks,
const size_t seq_len, const platform::DeviceContext& ctx);
void LinkMemories(const std::vector<Scope*>& step_scopes,
const std::vector<StateAttr>& memories, const size_t step_id,
const int offset);
void InitArgument(const ArgumentName& name, Argument* arg,
const framework::OperatorBase& op, bool is_grad = false);
} // namespace rnn
} // namespace operators
} // namespace paddle
......@@ -62,7 +62,7 @@ class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "");
AddOutput("Out", "");
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
......@@ -95,7 +95,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
auto &in_var_tensor = in_var->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["data_type"] = framework::ToDataType(in_var_tensor.type());
attrs["dtype"] = framework::ToDataType(in_var_tensor.type());
attrs["shape"] = framework::vectorize2int(in_var_tensor.dims());
attrs["value"] = 0.0f;
......@@ -121,7 +121,7 @@ class RNNMemoryHelperGradOpInfoMaker
AddInput("X", "");
AddInput("Out", "");
AddOutput(framework::GradVarName("X"), "");
AddAttr<int>("data_type",
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/roi_pool_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static constexpr int kROISize = 5;
class ROIPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ROIPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ROIs"),
"Input(ROIs) of ROIPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ROIPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Argmax"),
"Output(Argmax) of ROIPoolOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");
PADDLE_ENFORCE(input_dims.size() == 4,
"The format of input tensor is NCHW.");
PADDLE_ENFORCE(rois_dims.size() == 2,
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");
PADDLE_ENFORCE(rois_dims[1] == kROISize,
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
PADDLE_ENFORCE_GT(pooled_height, 0,
"The pooled output height must be greater than 0");
PADDLE_ENFORCE_GT(pooled_width, 0,
"The pooled output width must be greater than 0");
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
"The spatial scale must be greater than 0");
auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] = input_dims[1];
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("Argmax", out_dims);
}
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};
class ROIPoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
"The gradient of X should not be null.");
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
}
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};
class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ROIPoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor), "
"the input of ROIPoolOp. "
"The format of input tensor is NCHW. Where N is batch size, "
"C is the number of input channels, "
"H is the height of the feature, and "
"W is the width of the feature.");
AddInput("ROIs",
"(Tensor), "
"ROIs (Regions of Interest) to pool over. "
"should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …]. "
"Where batch_id is the id of the data, "
"(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates.");
AddOutput("Out",
"(Tensor), "
"The output of ROIPoolOp is a 4-D tensor with shape "
"(num_rois, channels, pooled_h, pooled_w).");
AddOutput("Argmax",
"(Tensor), "
"Argmaxes corresponding to indices in X used "
"for gradient computation. Only output "
"if arg “is_test” is false.").AsIntermediate();
AddAttr<float>("spatial_scale",
"(float, default 1.0), "
"Multiplicative spatial scale factor "
"to translate ROI coords from their input scale "
"to the scale used when pooling.")
.SetDefault(1.0);
AddAttr<int>("pooled_height",
"(int, default 1), "
"The pooled output height.")
.SetDefault(1);
AddAttr<int>("pooled_width",
"(int, default 1), "
"The pooled output width.")
.SetDefault(1);
AddComment(R"DOC(
ROIPool operator
ROI Pooling for Faster-RCNN. The link below is a further introduction:
https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
roi_pool_grad, ops::ROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
roi_pool,
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, float>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
roi_pool_grad,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUPlace, float>,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUPlace, double>);
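As a quick check of the InferShape logic in ROIPoolOp above: with an NCHW input and a (num_rois, 5) ROI tensor, both Out and Argmax come out as (num_rois, channels, pooled_height, pooled_width). A small sketch of that shape rule (hypothetical helper, not the Paddle API):

#include <array>
#include <cassert>

// Mirrors ROIPoolOp::InferShape: the output keeps the channel count of X
// but replaces the batch size with num_rois and H/W with the pooled sizes.
std::array<int, 4> roi_pool_out_dims(const std::array<int, 4>& x_nchw,
                                     int num_rois, int pooled_h, int pooled_w) {
  return {num_rois, x_nchw[1], pooled_h, pooled_w};
}

int main() {
  std::array<int, 4> x = {8, 256, 14, 14};    // N, C, H, W
  auto out = roi_pool_out_dims(x, 32, 7, 7);  // 32 ROIs, 7x7 pooling
  assert(out[0] == 32 && out[1] == 256 && out[2] == 7 && out[3] == 7);
  return 0;
}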
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/roi_pool_op.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static constexpr int kROISize = 5;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T>
__global__ void GPUROIPoolForward(
const int nthreads, const T* input_data, const int64_t* input_rois,
const float spatial_scale, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
T* output_data, int64_t* argmax_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (; index < nthreads; index += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int roi_start_w = round(offset_input_rois[1] * spatial_scale);
int roi_start_h = round(offset_input_rois[2] * spatial_scale);
int roi_end_w = round(offset_input_rois[3] * spatial_scale);
int roi_end_h = round(offset_input_rois[4] * spatial_scale);
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
int maxidx = -1;
const T* offset_input_data =
input_data + (roi_batch_ind * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_data_index = h * width + w;
if (offset_input_data[input_data_index] > maxval) {
maxval = offset_input_data[input_data_index];
maxidx = input_data_index;
}
}
}
output_data[index] = maxval;
if (argmax_data) {
argmax_data[index] = maxidx;
}
}
}
template <typename T>
__global__ void GPUROIPoolBackward(
const int nthreads,
const int64_t* input_rois,
const T* output_grad,
const int64_t* argmax_data,
const int num_rois,
const float spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (; index < nthreads; index += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int input_offset = (roi_batch_ind * channels + c) * height * width;
int output_offset = (n * channels + c) * pooled_height * pooled_width;
const T* offset_output_grad = output_grad + output_offset;
T* offset_input_grad = input_grad + input_offset;
const int64_t* offset_argmax_data = argmax_data + output_offset;
int argmax = offset_argmax_data[ph * pooled_width + pw];
if (argmax != -1) {
platform::CudaAtomicAdd(offset_input_grad + argmax,
static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
}
}
}
template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<Tensor>("ROIs");
auto* out = ctx.Output<Tensor>("Out");
auto* argmax = ctx.Output<Tensor>("Argmax");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto in_dims = in->dims();
auto in_stride = framework::stride(in_dims);
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
size_t rois_num = rois->dims()[0];
if (rois_num == 0) return;
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
GPUROIPoolForward<T>
<<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size,
in->data<T>(),
rois->data<int64_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
out->mutable_data<T>(ctx.GetPlace()),
argmax->mutable_data<int64_t>(ctx.GetPlace()));
}
};
template <typename Place, typename T>
class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<Tensor>("ROIs");
auto* argmax = ctx.Input<Tensor>("Argmax");
auto* out_grad =
ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad =
ctx.Output<Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
size_t rois_num = rois->dims()[0];
int channels = in->dims()[1];
int height = in->dims()[2];
int width = in->dims()[3];
if (x_grad) {
x_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
int output_grad_size = out_grad->numel();
int blocks = NumBlocks(output_grad_size);
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPUROIPoolBackward<T>
<<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size,
rois->data<int64_t>(),
out_grad->data<T>(),
argmax->data<int64_t>(),
rois_num,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
x_grad->mutable_data<T>(ctx.GetPlace()));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
roi_pool,
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, float>,
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
roi_pool_grad,
ops::GPUROIPoolGradOpKernel<paddle::platform::GPUPlace, float>,
ops::GPUROIPoolGradOpKernel<paddle::platform::GPUPlace, double>);
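The GPU kernels above are launched with NumBlocks(output_size) blocks of kNumCUDAThreads threads, i.e. a ceiling division capped at kNumMaxinumNumBlocks; the grid-stride loop inside the kernel covers whatever the cap leaves out. A host-side sketch of that arithmetic (plain C++, same constants as the kernel file):

#include <algorithm>
#include <cassert>

constexpr int kNumCUDAThreads = 512;
constexpr int kNumMaxinumNumBlocks = 4096;

// Ceiling division of n by the block size, capped at the maximum grid size.
int NumBlocks(int n) {
  return std::min((n + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaxinumNumBlocks);
}

int main() {
  assert(NumBlocks(1) == 1);           // any non-empty workload gets one block
  assert(NumBlocks(512) == 1);         // exactly one full block
  assert(NumBlocks(513) == 2);         // one extra element rounds up
  assert(NumBlocks(10000000) == 4096); // capped; the grid-stride loop handles the rest
  return 0;
}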
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class CPUROIPoolOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::Tensor>("ROIs");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* argmax = ctx.Output<framework::Tensor>("Argmax");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
auto in_stride = framework::stride(in_dims);
auto argmax_stride = framework::stride(argmax->dims());
auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out->dims());
const T* input_data = in->data<T>();
const int64_t* rois_data = rois->data<int64_t>();
T* output_data = out->mutable_data<T>(ctx.GetPlace());
int64_t* argmax_data = argmax->mutable_data<int64_t>(ctx.GetPlace());
for (int n = 0; n < rois_num; ++n) {
int roi_batch_id = rois_data[0];
PADDLE_ENFORCE_GE(roi_batch_id, 0);
PADDLE_ENFORCE_LT(roi_batch_id, batch_size);
rois_data += roi_stride[0];
}
rois_data = rois->data<int64_t>();
for (int n = 0; n < rois_num; ++n) {
int roi_batch_id = rois_data[0];
int roi_start_w = round(rois_data[1] * spatial_scale);
int roi_start_h = round(rois_data[2] * spatial_scale);
int roi_end_w = round(rois_data[3] * spatial_scale);
int roi_end_h = round(rois_data[4] * spatial_scale);
// Force malformed ROIs to be 1x1
int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
const float bin_size_h =
static_cast<float>(roi_height) / static_cast<float>(pooled_height);
const float bin_size_w =
static_cast<float>(roi_width) / static_cast<float>(pooled_width);
const T* batch_data = input_data + roi_batch_id * in_stride[0];
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
// Compute pooling region for this output unit:
// start (included) = floor(ph * roi_height / pooled_height_)
// end (excluded) = ceil((ph + 1) * roi_height / pooled_height_)
int hstart =
static_cast<int>(floor(static_cast<float>(ph) * bin_size_h));
int wstart =
static_cast<int>(floor(static_cast<float>(pw) * bin_size_w));
int hend =
static_cast<int>(ceil(static_cast<float>(ph + 1) * bin_size_h));
int wend =
static_cast<int>(ceil(static_cast<float>(pw + 1) * bin_size_w));
hstart = std::min(std::max(hstart + roi_start_h, 0), height);
hend = std::min(std::max(hend + roi_start_h, 0), height);
wstart = std::min(std::max(wstart + roi_start_w, 0), width);
wend = std::min(std::max(wend + roi_start_w, 0), width);
const int pool_index = ph * pooled_width + pw;
// Define an empty pooling region to be zero
bool is_empty = (hend <= hstart) || (wend <= wstart);
output_data[pool_index] =
is_empty ? 0 : -std::numeric_limits<T>::max();
argmax_data[pool_index] = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
const int index = h * width + w;
if (batch_data[index] > output_data[pool_index]) {
output_data[pool_index] = batch_data[index];
argmax_data[pool_index] = index;
}
}
}
}
}
batch_data += in_stride[1];
output_data += out_stride[1];
argmax_data += argmax_stride[1];
}
// Increment ROI data pointer
rois_data += roi_stride[0];
}
return;
}
};
template <typename Place, typename T>
class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::Tensor>("ROIs");
auto* argmax = ctx.Input<framework::Tensor>("Argmax");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* x_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
if (x_grad) {
int channels = in->dims()[1];
auto in_stride = framework::stride(in->dims());
auto roi_stride = framework::stride(rois->dims());
const int64_t* rois_data = rois->data<int64_t>();
int rois_num = rois->dims()[0];
T* x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
size_t roi_offset = roi_stride[0];
size_t batch_offset = in_stride[0];
size_t channel_offset = in_stride[1];
const T* out_grad_data = out_grad->data<T>();
size_t pool_channel_offset = pooled_height * pooled_width;
const int64_t* argmax_data = argmax->data<int64_t>();
for (size_t n = 0; n < rois_num; ++n) {
size_t roi_batch_idx = rois_data[0];
T* batch_grad_data = x_grad_data + batch_offset * roi_batch_idx;
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
size_t pool_index = ph * pooled_width + pw;
if (argmax_data[pool_index] >= 0) {
size_t index = static_cast<size_t>(argmax_data[pool_index]);
batch_grad_data[index] += out_grad_data[pool_index];
}
}
}
batch_grad_data += channel_offset;
out_grad_data += pool_channel_offset;
argmax_data += pool_channel_offset;
}
rois_data += roi_offset;
}
}
}
};
} // namespace operators
} // namespace paddle
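The pooling bins in CPUROIPoolOpKernel follow the floor/ceil rule spelled out in the comment above, so adjacent bins may overlap by a row or column. A small worked example, assuming a 7-row ROI pooled to pooled_height = 2: bin 0 covers rows [0, 4) and bin 1 covers rows [3, 7). A sketch of that computation (hypothetical helper, mirroring the kernel's arithmetic):

#include <cassert>
#include <cmath>

// Bin boundaries for one pooled cell, as in CPUROIPoolOpKernel:
// start (included) = floor(ph * roi_height / pooled_height)
// end   (excluded) = ceil((ph + 1) * roi_height / pooled_height)
void bin_bounds(int ph, int roi_height, int pooled_height, int* hstart,
                int* hend) {
  const float bin_size_h =
      static_cast<float>(roi_height) / static_cast<float>(pooled_height);
  *hstart = static_cast<int>(std::floor(static_cast<float>(ph) * bin_size_h));
  *hend = static_cast<int>(std::ceil(static_cast<float>(ph + 1) * bin_size_h));
}

int main() {
  int hstart, hend;
  bin_bounds(0, 7, 2, &hstart, &hend);
  assert(hstart == 0 && hend == 4);  // bin 0: rows [0, 4)
  bin_bounds(1, 7, 2, &hstart, &hend);
  assert(hstart == 3 && hend == 7);  // bin 1: rows [3, 7), overlaps bin 0
  return 0;
}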
......@@ -26,7 +26,7 @@ using LoD = framework::LoD;
template <typename T>
inline LoD SequenceSliceLoD(const T& in, const int64_t* offset_data,
const int64_t* length_data) {
const int64_t* length_data) {
auto out_lod = in.lod();
size_t lod_offset = 0;
......@@ -34,7 +34,7 @@ inline LoD SequenceSliceLoD(const T& in, const int64_t* offset_data,
out_lod[0][0] = 0;
for (size_t i = 0; i < n; ++i) {
lod_offset += length_data[i];
out_lod[0][i+1] = lod_offset;
out_lod[0][i + 1] = lod_offset;
}
return out_lod;
}
......@@ -51,8 +51,7 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
auto lod = in->lod();
auto n = lod[0].size() - 1;
PADDLE_ENFORCE_EQ(lod.size(), 1UL,
"Only support one level sequence now.");
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(length->dims()[0]),
"The size of input-sequence and length-array should be the same")
......@@ -67,23 +66,23 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
framework::CopyFrom(*offset, platform::CPUPlace(), ctx.device_context(),
&offset_cpu);
offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
framework::CopyFrom(*length, platform::CPUPlace(), ctx.device_context(),
&length_cpu);
length_data = length_cpu.data<int64_t>();
}
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_LT(0, offset_data[i],
"The offset[%d] must greater than zero.", i)
"The offset[%d] must greater than zero.", i)
PADDLE_ENFORCE_LT(0, length_data[i],
"The length[%d] must greater than zero.", i)
PADDLE_ENFORCE_LT(
lod[0][i] + offset_data[i] + length_data[i],
lod[0][i + 1],
"The target tensor's length overflow.")
"The length[%d] must greater than zero.", i)
PADDLE_ENFORCE_LT(lod[0][i] + offset_data[i] + length_data[i],
lod[0][i + 1], "The target tensor's length overflow.")
}
out->mutable_data<T>(ctx.GetPlace());
......@@ -98,14 +97,12 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
size_t out_offset = 0;
for (size_t i = 0; i < n; ++i) {
Tensor in_t =
in->Slice(static_cast<int>(lod[0][i] + offset_data[i]),
static_cast<int>(lod[0][i] + offset_data[i] +
length_data[i]));
StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(),
in_stride, in_t.dims(), out_stride,
out->data<T>() + out_offset);
Tensor in_t = in->Slice(
static_cast<int>(lod[0][i] + offset_data[i]),
static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(), in_stride,
in_t.dims(), out_stride, out->data<T>() + out_offset);
out_offset += length_data[i] * in_stride[0];
}
}
......@@ -130,11 +127,13 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
framework::CopyFrom(*offset, platform::CPUPlace(), ctx.device_context(),
&offset_cpu);
offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
framework::CopyFrom(*length, platform::CPUPlace(), ctx.device_context(),
&length_cpu);
length_data = length_cpu.data<int64_t>();
}
......@@ -162,8 +161,8 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
out_grad_stride, out_grad_t.dims(), x_grad_stride,
x_grad_t.data<T>());
out_grad_stride, out_grad_t.dims(), x_grad_stride,
x_grad_t.data<T>());
}
}
}
......
......@@ -101,8 +101,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
} else {
auto &dout_tensor = dout_var->Get<framework::LoDTensor>();
auto height = dout_tensor.dims()[0];
dx_tensor.Slice(0, static_cast<int>(height))
.CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
auto slice = dx_tensor.Slice(0, static_cast<int>(height));
framework::CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
if (dx_tensor.dims()[0] < height) {
auto rest_tensor = dx_tensor.Slice(
static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
......
......@@ -59,7 +59,7 @@ Then the ratio of the exponential of the given dimension and the sum of
exponential values of all the other dimensions is the output of the softmax
operator.
For each row `i` and each column `j` in input X, we have:
For each row $i$ and each column $j$ in Input(X), we have:
$$Y[i, j] = \frac{\exp(X[i, j])}{\sum_j \exp(X[i, j])}$$
)DOC");
......
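A tiny numeric check of the row-wise softmax formula above, computed for a single row (a minimal standalone sketch, independent of the operator code):

#include <cassert>
#include <cmath>
#include <vector>

// Y[i, j] = exp(X[i, j]) / sum_j exp(X[i, j]), applied to one row of X.
std::vector<double> softmax_row(const std::vector<double>& x) {
  double sum = 0.0;
  std::vector<double> y(x.size());
  for (size_t j = 0; j < x.size(); ++j) sum += std::exp(x[j]);
  for (size_t j = 0; j < x.size(); ++j) y[j] = std::exp(x[j]) / sum;
  return y;
}

int main() {
  auto y = softmax_row({0.0, 0.0});  // equal logits -> uniform output
  assert(std::fabs(y[0] - 0.5) < 1e-12 && std::fabs(y[1] - 0.5) < 1e-12);
  auto z = softmax_row({1.0, 2.0, 3.0});
  assert(std::fabs(z[0] + z[1] + z[2] - 1.0) < 1e-12);  // each row sums to one
  return 0;
}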
......@@ -67,15 +67,15 @@ The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
$$Loss_j = \f$ -\text{Logit}_{Label_j} +
$$Loss_j = -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
j = 1, ..., K $\f$$
j = 1,..., K$$
2) Soft label (each sample can have a distribution over all classes)
$$Loss_j = \f$ -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i -
$$Loss_j = -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
j = 1,...,K $\f$$
j = 1,...,K$$
)DOC");
}
......
......@@ -49,7 +49,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
cpu_mask->ShareDataWith(mask);
} else if (platform::is_gpu_place(mask.place())) {
#ifdef PADDLE_WITH_CUDA
cpu_mask->CopyFrom(mask, platform::CPUPlace(), dev_ctx);
framework::CopyFrom(mask, platform::CPUPlace(), dev_ctx, cpu_mask.get());
#else
PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option");
#endif
......@@ -105,10 +105,11 @@ class SplitLoDTensorOp : public framework::OperatorBase {
continue;
}
// out[offset: offset+len] = x[each_range.begin: each_range.end]
out->Slice(static_cast<int>(offset), static_cast<int>(offset + len))
.CopyFrom(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)),
x.place(), dev_ctx);
auto slice = out->Slice(static_cast<int>(offset),
static_cast<int>(offset + len));
framework::CopyFrom(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)),
x.place(), dev_ctx, &slice);
offset += len;
}
}
......
......@@ -176,4 +176,6 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
ops::SumOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>,
ops::SumKernel<paddle::platform::CPUPlace, double>);
ops::SumKernel<paddle::platform::CPUPlace, double>,
ops::SumKernel<paddle::platform::CPUPlace, int>,
ops::SumKernel<paddle::platform::CPUPlace, int64_t>);
......@@ -14,4 +14,6 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>,
ops::SumKernel<paddle::platform::GPUPlace, double>);
ops::SumKernel<paddle::platform::GPUPlace, double>,
ops::SumKernel<paddle::platform::GPUPlace, int>,
ops::SumKernel<paddle::platform::GPUPlace, int64_t>);
......@@ -102,8 +102,8 @@ class SumKernel : public framework::OpKernel<T> {
out_array.resize(i + 1);
}
if (out_array[i].numel() == 0) {
out_array[i].CopyFrom(in_array[i], in_array[i].place(),
context.device_context());
framework::CopyFrom(in_array[i], in_array[i].place(),
context.device_context(), &out_array[i]);
out_array[i].set_lod(in_array[i].lod());
} else {
PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod());
......
......@@ -38,7 +38,7 @@ class WriteToArrayOp : public ArrayOp {
out->resize(offset + 1);
}
auto *out_tensor = &out->at(offset);
out_tensor->CopyFrom(x_tensor, dev_ctx.GetPlace(), dev_ctx);
CopyFrom(x_tensor, dev_ctx.GetPlace(), dev_ctx, out_tensor);
out_tensor->set_lod(x_tensor.lod());
}
};
......@@ -116,7 +116,8 @@ class ReadFromArrayOp : public ArrayOp {
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
size_t offset = GetOffset(scope, dev_ctx);
PADDLE_ENFORCE_LT(offset, x_array.size());
out_tensor->CopyFrom(x_array[offset], dev_ctx.GetPlace(), dev_ctx);
framework::CopyFrom(x_array[offset], dev_ctx.GetPlace(), dev_ctx,
out_tensor);
out_tensor->set_lod(x_array[offset].lod());
}
};
......
......@@ -66,7 +66,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
......@@ -99,7 +99,7 @@ uniform distribution.
"Random seed used for generating samples. "
"0 means use a seed generated by the system.")
.SetDefault(0);
AddAttr<int>("data_type", "(int, default 5(FP32)) Output tensor data type")
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::DataType::FP32);
}
};
......
......@@ -180,7 +180,7 @@ class WhileGradOp : public framework::OperatorBase {
if (var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["data_type"] = framework::ToDataType(inside_tensor.type());
attrs["dtype"] = framework::ToDataType(inside_tensor.type());
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f;
......
cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog)
if(WITH_GPU)
cc_library(enforce SRCS enforce.cc DEPS nccl)
else()
cc_library(enforce SRCS enforce.cc)
endif()
cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece enforce)
cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog enforce)
cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce)
cc_library(place SRCS place.cc)
cc_library(place SRCS place.cc DEPS enforce)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
add_subdirectory(dynload)
cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece)
IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
ELSE()
......
......@@ -31,6 +31,16 @@ constexpr int PADDLE_CUDA_NUM_THREADS = 512;
// For atomicAdd.
USE_CUDA_ATOMIC(Add, float);
USE_CUDA_ATOMIC(Add, int);
USE_CUDA_ATOMIC(Add, unsigned int);
USE_CUDA_ATOMIC(Add, unsigned long long int);
CUDA_ATOMIC_WRAPPER(Add, int64_t) {
static_assert(sizeof(int64_t) == sizeof(long long int),
"long long should be int64");
return CudaAtomicAdd(reinterpret_cast<unsigned long long int*>(address),
static_cast<unsigned long long int>(val));
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
USE_CUDA_ATOMIC(Add, double);
......
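The int64_t wrapper above forwards to the unsigned long long overload of atomicAdd; that is sound only because both types are 64 bits wide and signed addition matches unsigned addition bit-for-bit under two's complement (an assumption CUDA targets satisfy). A host-side sketch of that assumption, with an ordinary in-place add standing in for the device atomic:

#include <cassert>
#include <cstdint>

// Host-side stand-in for CudaAtomicAdd(int64_t*, int64_t): reinterpret the
// address as unsigned long long (same width), add there, and rely on
// two's-complement wrap-around so the signed result comes out right.
int64_t add_via_unsigned(int64_t* address, int64_t val) {
  static_assert(sizeof(int64_t) == sizeof(unsigned long long int),
                "long long should be int64");
  auto* u = reinterpret_cast<unsigned long long int*>(address);
  *u += static_cast<unsigned long long int>(val);  // atomicAdd on the device
  return *address;
}

int main() {
  int64_t counter = -5;
  add_via_unsigned(&counter, 12);
  assert(counter == 7);
  return 0;
}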
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc
DEPS dynamic_loader nccl)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/enforce.h"
namespace paddle {
namespace platform {} // namespace platform
} // namespace paddle
......@@ -49,7 +49,6 @@ limitations under the License. */
namespace paddle {
namespace platform {
namespace {
#ifdef __GNUC__
inline std::string demangle(std::string name) {
int status = -4; // some arbitrary value to eliminate the compiler warning
......@@ -60,7 +59,6 @@ inline std::string demangle(std::string name) {
#else
inline std::string demangle(std::string name) { return name; }
#endif
}
struct EnforceNotMet : public std::exception {
std::exception_ptr exp_;
......
if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc
DEPS pybind python backward proto_desc tensor_array paddle_memory executor prune
DEPS pybind python backward proto_desc paddle_memory executor prune
${GLOB_OP_LIB})
endif(WITH_PYTHON)
cc_binary(print_operators_doc SRCS print_operators_doc.cc DEPS ${GLOB_OP_LIB} tensor_array)
cc_binary(print_operators_doc SRCS print_operators_doc.cc DEPS ${GLOB_OP_LIB})
......@@ -202,9 +202,9 @@ void BindVarDsec(py::module &m) {
},
py::return_value_policy::reference)
.def("set_shape", &VarDescBind::SetShape)
.def("set_data_type", &VarDescBind::SetDataType)
.def("set_dtype", &VarDescBind::SetDataType)
.def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
.def("data_type", &VarDescBind::GetDataType)
.def("dtype", &VarDescBind::GetDataType)
.def("lod_level", &VarDescBind::GetLodLevel)
.def("set_lod_level", &VarDescBind::SetLoDLevel)
.def("type", &VarDescBind::GetType)
......
This diff is collapsed.
......@@ -8,7 +8,7 @@ def _clone_var_in_block_(block, var):
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.data_type,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=True)
......@@ -33,6 +33,9 @@ class Evaluator(object):
else:
self._main_program = g_main_program
def states(self):
return self._states
def _update_ops(self, *args, **kwargs):
"""
append update ops to the global states
......@@ -57,7 +60,7 @@ class Evaluator(object):
attrs={
"shape": g_var.shape,
"value": .0,
"data_type": 5,
"dtype": 5,
})
block.append_op(
type="scale", inputs={"X": zeros}, outputs={"Out": g_var})
......@@ -93,7 +96,7 @@ class Accuracy(Evaluator):
def _update_ops(self, input, label, k=1, **kwargs):
block = self._main_program.global_block()
topk_out = block.create_var(dtype=input.data_type)
topk_out = block.create_var(dtype=input.dtype)
topk_indices = block.create_var(dtype="int64")
block.append_op(
type="top_k",
......@@ -122,16 +125,16 @@ class Accuracy(Evaluator):
inputs={"X": [self._states["Total"]]},
outputs={"Out": [self._states["Total"]]},
attrs={
"in_data_type": 5, # float32
"out_data_type": 2, #int32
"in_dtype": 5, # float32
"out_dtype": 2, # int32
})
block.append_op(
type="cast",
inputs={"X": [self._states["Correct"]]},
outputs={"Out": [self._states["Correct"]]},
attrs={
"in_data_type": 5,
"out_data_type": 2,
"in_dtype": 5,
"out_dtype": 2,
})
block.append_op(
......@@ -153,7 +156,7 @@ class Accuracy(Evaluator):
else:
eval_program = Program()
block = eval_program.global_block()
eval_out = block.create_var(dtype=self._states["Total"].data_type)
eval_out = block.create_var(dtype=self._states["Total"].dtype)
e_total = _clone_var_in_block_(block, self._states["Total"])
e_correct = _clone_var_in_block_(block, self._states["Correct"])
block.append_op(
......@@ -161,16 +164,16 @@ class Accuracy(Evaluator):
inputs={"X": [e_total]},
outputs={"Out": [e_total]},
attrs={
"in_data_type": 2, #int32
"out_data_type": 5, #float32
"in_dtype": 2, # int32
"out_dtype": 5, # float32
})
block.append_op(
type="cast",
inputs={"X": [e_correct]},
outputs={"Out": [e_correct]},
attrs={
"in_data_type": 2,
"out_data_type": 5,
"in_dtype": 2,
"out_dtype": 5,
})
block.append_op(
type="elementwise_div",
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -92,7 +92,7 @@ class Optimizer(object):
var = self.helper.create_global_variable(
name=unique_name(name),
persistable=True,
dtype=dtype or param.data_type,
dtype=dtype or param.dtype,
type=param.type,
shape=param.shape)
self.helper.set_variable_initializer(
......@@ -202,7 +202,7 @@ class Optimizer(object):
"""
params_grads = append_backward_ops(loss, parameter_list, no_grad_set or
set())
# Add regularization if any
# Add regularization if any
params_grads = append_regularization_ops(params_grads)
optimize_ops = self.create_optimization_pass(params_grads, loss,
startup_program)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.