“824b84d185046e10883b279a5d0289f29fe2e98d”上不存在“paddle/infrt/dialect/phi/infrt_phi_tensor.cc”
提交 5229ccbd 编写于 作者: Y Yang Yang

merge develop

...@@ -15,6 +15,13 @@ ...@@ -15,6 +15,13 @@
include(ExternalProject) include(ExternalProject)
set(BOOST_PROJECT "extern_boost") set(BOOST_PROJECT "extern_boost")
# To release PaddlePaddle as a pip package, we have to follow the
# manylinux1 standard, which features as old Linux kernels and
# compilers as possible and recommends CentOS 5. Indeed, the earliest
# CentOS version that works with NVIDIA CUDA is CentOS 6. And a new
# version of boost, say, 1.66.0, doesn't build on CentOS 6. We
# checked that the devtools package of CentOS 6 installs boost 1.41.0.
# So we use 1.41.0 here.
set(BOOST_VER "1.41.0") set(BOOST_VER "1.41.0")
set(BOOST_TAR "boost_1_41_0") set(BOOST_TAR "boost_1_41_0")
set(BOOST_URL "http://paddlepaddledeps.s3-website-us-west-1.amazonaws.com/${BOOST_TAR}.tar.gz") set(BOOST_URL "http://paddlepaddledeps.s3-website-us-west-1.amazonaws.com/${BOOST_TAR}.tar.gz")
......
# Design Doc: Parallel_Do in PaddlePaddle
In PaddlePaddle, we use parallel_do primitive to represent multithread data parallel processing.
## Design overview
The definition of a parallel_do op looks like the following
```c++
AddInput(kInputs, "Inputs needed to be split onto different devices").AsDuplicable();
AddInput(kParameters, "Parameters are duplicated over different devices")
.AsDuplicable();
AddInput(kPlaces, "Devices used for parallel processing");
AddOutput(kOutputs, "Outputs needed to be merged from different devices").AsDuplicable();
AddOutput(kParallelScopes,
"Scopes for all local variables in forward pass. One scope for each device");
AddAttr<framework::BlockDesc *>(kParallelBlock,
"List of operaters to be executed in parallel");
```
A vanilla implementation of parallel_do can be shown as the following (`|` means single thread and
`||||` means multiple threads)
```
In the forward pass
| Split input onto different devices
| Copy parameter to onto different devices
|||| Compute forward pass in parallel
| Merge output from different devices
In the backward pass
| Split output@grad onto different devices
|||| Compute backward pass in parallel
| accumulate param@grad from different devices to the first device
| Merge input@grad from different devices
 | Copy param@grad to the place of parallel_do_op
```
This implementation allows to write mixed device program like this
```python
# get embedding feature on CPU
feature = some_cpu_only_op(data)
gpu_places = get_place(use_gpu=True)
# parallel processing on multiple GPUs
pd = ParallelDo(gpu_places)
with pd.do():
read_input(feature)
prediction = my_net(feature)
write_output(prediction)
prediction = pd()
loss = cross_entropy(prediction, label)
```
And the programDesc are like the following
```
# start_program will be run by executor(CPUPlace), all w1, w2 will be allocated on CPU
start_program
{
vars: w1, w2
ops: init(w1), init(w2)
}
main_program
{
block0 {
vars: data, places, w1, w2
ops: data, get_place, parallel_do(block1),
parallel_do_grad(block2),
sgd(w2, w2_grad),
sgd(w1, w1_grad)
}
block1 {
parent_block: 0
vars: data, h1, h2, loss
ops: fc, fc, softmax
}
block2 {
parent_block: 1
vars: data_grad, h1_grad, h2_grad, loss_gard, w1_grad, w2_grad
ops: softmax_grad,
fc_grad
fc_grad
}
}
```
## Proformance Imporvement
There are serial places we can make this parallel_do faster.
### forward: split input onto different devices
If the input of the parallel_do is independent from any prior opeartors, we can avoid this step by
prefetching the input onto different devices in a seperate background thread. And the python code
looks like this.
```python
pd = ParallelDo(gpu_places)
with pd.do():
   feature = get_data_from_prefetch_queue(gpu_places)
prediction = my_net(feature)
write_output(activation)
```
### forward: Copy parameter to onto different devices
We can avoid this step by making each device have a copy of the parameter. This requires:
1. `fluid.default_start_up_program()` to be run on all devices
1. In the backward, allreduce param@grad at different devices, this requires
1. `backward.py` add `allreduce` operators at parallel_do_grad
1. `allreduce` operators need to be called in async mode to achieve maximum throughput
1. apply gradients related op(i.e. cliping, normalization, decay, sgd) on different devices in parallel
By doing so, we also avoided "backward: accumulate param@grad from different devices to the first device".
And the ProgramDesc looks like the following
```
# w1, w2 will be allocated on all GPUs
start_program
{
block0 {
parallel_do(block1)
}
block1 {
parent_block: 0
vars: w1, w2
ops: init(w1), init(w2)
}
}
main_program
{
block0 {
vars: data, places, w1, w2
ops: data, get_place, parallel_do(block1),
parallel_do_grad(block2), # append_backward
parallel_do(block3) # append_optimization
}
block1 {
parent_block: 0
vars: data, h1, h2, loss
ops: fc, fc, softmax
}
block2 {
parent_block: 1
vars: data_grad, h1_grad, h2_grad, loss_gard, w1_grad, w2_grad
ops: softmax_grad,
fc_grad, allreduce(places, scopes, w1_grad),
fc_grad, allreduce(places, scopes, w2_grad)
}
block3 {
parent_block: 0
vars: lr
ops: sgd(w2, w2_grad),
sgd(w1, w1_grad)
}
}
```
...@@ -37,7 +37,7 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place, ...@@ -37,7 +37,7 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
<< " dst_place: " << dst_place; << " dst_place: " << dst_place;
auto* dev_ctx = GetDeviceContext(in.place(), dst_place); auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
dev_ctx->Wait(); dev_ctx->Wait();
Copy(in, dst_place, *dev_ctx, out); TensorCopy(in, dst_place, *dev_ctx, out);
dev_ctx->Wait(); dev_ctx->Wait();
} }
......
...@@ -157,7 +157,7 @@ TEST(Operator, CPUtoGPU) { ...@@ -157,7 +157,7 @@ TEST(Operator, CPUtoGPU) {
auto dev_ctx = pool.Get(cuda_place); auto dev_ctx = pool.Get(cuda_place);
paddle::framework::Tensor output_tensor; paddle::framework::Tensor output_tensor;
Copy(output2->Get<LoDTensor>(), paddle::platform::CPUPlace(), *dev_ctx, TensorCopy(output2->Get<LoDTensor>(), paddle::platform::CPUPlace(), *dev_ctx,
&output_tensor); &output_tensor);
dev_ctx->Wait(); dev_ctx->Wait();
......
...@@ -75,8 +75,10 @@ static void CheckTensorNANOrInf(const std::string& name, ...@@ -75,8 +75,10 @@ static void CheckTensorNANOrInf(const std::string& name,
tensor.type().hash_code() != typeid(double).hash_code()) { tensor.type().hash_code() != typeid(double).hash_code()) {
return; return;
} }
PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name); PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name); "Tensor %s contains Inf", name);
PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
"Tensor %s contains NAN", name);
} }
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
...@@ -120,12 +122,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -120,12 +122,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(op->Type(), pool.Get(place_)); platform::RecordEvent record_event(op->Type(), pool.Get(place_));
VLOG(3) << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: " VLOG(2) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_); << memory::memory_usage(place_);
......
...@@ -46,7 +46,7 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { ...@@ -46,7 +46,7 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
if (!platform::is_cpu_place(t.place())) { if (!platform::is_cpu_place(t.place())) {
LoDTensor tt; LoDTensor tt;
framework::Copy(t, platform::CPUPlace(), &tt); framework::TensorCopy(t, platform::CPUPlace(), &tt);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(t.place()); auto &dev_ctx = *pool.Get(t.place());
dev_ctx.Wait(); dev_ctx.Wait();
...@@ -255,7 +255,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor, ...@@ -255,7 +255,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
} }
} }
// the 3st field, Tensor // the 3st field, Tensor
SerializeToStream(os, static_cast<Tensor>(tensor), dev_ctx); TensorToStream(os, static_cast<Tensor>(tensor), dev_ctx);
} }
void DeserializeFromStream(std::istream &is, LoDTensor *tensor, void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
...@@ -282,7 +282,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, ...@@ -282,7 +282,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
} }
} }
// the 3st filed, Tensor // the 3st filed, Tensor
DeserializeFromStream(is, static_cast<Tensor *>(tensor), dev_ctx); TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
} }
std::vector<LoDTensor> LoDTensor::SplitLoDTensor( std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
...@@ -308,14 +308,14 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor( ...@@ -308,14 +308,14 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
if (lod().empty()) { if (lod().empty()) {
auto src = Slice(begin, end); auto src = Slice(begin, end);
auto &dst_place = places[i]; auto &dst_place = places[i];
framework::Copy(src, dst_place, &dst); framework::TensorCopy(src, dst_place, &dst);
} else { } else {
auto lod_and_offset = GetSubLoDAndAbsoluteOffset(lod(), begin, end, 0); auto lod_and_offset = GetSubLoDAndAbsoluteOffset(lod(), begin, end, 0);
auto &offset = lod_and_offset.second; auto &offset = lod_and_offset.second;
auto src = Slice(offset.first, offset.second); auto src = Slice(offset.first, offset.second);
auto &dst_place = places[i]; auto &dst_place = places[i];
framework::Copy(src, dst_place, &dst); framework::TensorCopy(src, dst_place, &dst);
LoD my_lod; LoD my_lod;
for (auto &l : lod_and_offset.first) { for (auto &l : lod_and_offset.first) {
...@@ -369,7 +369,7 @@ void LoDTensor::MergeLoDTensor( ...@@ -369,7 +369,7 @@ void LoDTensor::MergeLoDTensor(
for (auto *src : lod_tensors) { for (auto *src : lod_tensors) {
int end = begin + src->dims()[0]; int end = begin + src->dims()[0];
auto dst = Slice(begin, end); auto dst = Slice(begin, end);
framework::Copy(*src, dst_place, &dst); framework::TensorCopy(*src, dst_place, &dst);
begin = end; begin = end;
} }
} }
......
...@@ -175,7 +175,7 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level, ...@@ -175,7 +175,7 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
for (size_t ins = 0; ins < num_instances; ins++) { for (size_t ins = 0; ins < num_instances; ins++) {
for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) { for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) {
auto slice = tensor.Slice(elem, elem + 1); auto slice = tensor.Slice(elem, elem + 1);
Copy(source.Slice(ins, ins + 1), platform::CPUPlace(), TensorCopy(source.Slice(ins, ins + 1), platform::CPUPlace(),
platform::CPUDeviceContext(), &slice); platform::CPUDeviceContext(), &slice);
} }
} }
......
...@@ -291,7 +291,7 @@ class Vector { ...@@ -291,7 +291,7 @@ class Vector {
void CopyToCPU() const { void CopyToCPU() const {
// COPY GPU Data To CPU // COPY GPU Data To CPU
Copy(cuda_vec_, platform::CPUPlace(), &cpu_vec_); TensorCopy(cuda_vec_, platform::CPUPlace(), &cpu_vec_);
WaitPlace(cuda_vec_.place()); WaitPlace(cuda_vec_.place());
} }
...@@ -305,13 +305,14 @@ class Vector { ...@@ -305,13 +305,14 @@ class Vector {
void ImmutableCUDA(platform::Place place) const { void ImmutableCUDA(platform::Place place) const {
if (IsDirty()) { if (IsDirty()) {
if (IsInCPU()) { if (IsInCPU()) {
Copy(cpu_vec_, boost::get<platform::CUDAPlace>(place), &cuda_vec_); TensorCopy(cpu_vec_, boost::get<platform::CUDAPlace>(place),
&cuda_vec_);
WaitPlace(place); WaitPlace(place);
UnsetFlag(kDirty); UnsetFlag(kDirty);
SetFlag(kDataInCUDA); SetFlag(kDataInCUDA);
} else if (IsInCUDA() && !(place == cuda_vec_.place())) { } else if (IsInCUDA() && !(place == cuda_vec_.place())) {
framework::Tensor tmp; framework::Tensor tmp;
Copy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp); TensorCopy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp);
WaitPlace(cuda_vec_.place()); WaitPlace(cuda_vec_.place());
cuda_vec_.ShareDataWith(tmp); cuda_vec_.ShareDataWith(tmp);
// Still dirty // Still dirty
...@@ -322,13 +323,14 @@ class Vector { ...@@ -322,13 +323,14 @@ class Vector {
} else { } else {
if (!IsInCUDA()) { if (!IsInCUDA()) {
// Even data is not dirty. However, data is not in CUDA. Copy data. // Even data is not dirty. However, data is not in CUDA. Copy data.
Copy(cpu_vec_, boost::get<platform::CUDAPlace>(place), &cuda_vec_); TensorCopy(cpu_vec_, boost::get<platform::CUDAPlace>(place),
&cuda_vec_);
WaitPlace(place); WaitPlace(place);
SetFlag(kDataInCUDA); SetFlag(kDataInCUDA);
} else if (!(place == cuda_vec_.place())) { } else if (!(place == cuda_vec_.place())) {
framework::Tensor tmp; framework::Tensor tmp;
WaitPlace(cuda_vec_.place()); WaitPlace(cuda_vec_.place());
Copy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp); TensorCopy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp);
WaitPlace(cuda_vec_.place()); WaitPlace(cuda_vec_.place());
WaitPlace(place); WaitPlace(place);
cuda_vec_.ShareDataWith(tmp); cuda_vec_.ShareDataWith(tmp);
......
...@@ -105,7 +105,7 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) { ...@@ -105,7 +105,7 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
} }
} }
Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]); Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
Copy(buffer_[i][j], platform::CPUPlace(), &dst); TensorCopy(buffer_[i][j], platform::CPUPlace(), &dst);
dst_offset += ins_shape[0]; dst_offset += ins_shape[0];
} }
out_tensor.set_lod(batch_lod); out_tensor.set_lod(batch_lod);
......
...@@ -34,7 +34,7 @@ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, ...@@ -34,7 +34,7 @@ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
os.write(reinterpret_cast<const char*>(&height), sizeof(height)); os.write(reinterpret_cast<const char*>(&height), sizeof(height));
} }
// the 4st field, Tensor data // the 4st field, Tensor data
SerializeToStream(os, selected_rows.value(), dev_ctx); TensorToStream(os, selected_rows.value(), dev_ctx);
} }
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows, void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
...@@ -62,7 +62,7 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows, ...@@ -62,7 +62,7 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
selected_rows->set_height(height); selected_rows->set_height(height);
} }
// the 4st field, tensor which contains the data // the 4st field, tensor which contains the data
DeserializeFromStream(is, selected_rows->mutable_value(), dev_ctx); TensorFromStream(is, selected_rows->mutable_value(), dev_ctx);
} }
} // namespace framework } // namespace framework
......
...@@ -16,6 +16,76 @@ ...@@ -16,6 +16,76 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void TensorCopy(const Tensor& src, const platform::Place& dst_place,
const platform::DeviceContext& ctx, Tensor* dst) {
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
src.check_memory_size();
dst->Resize(src.dims());
dst->set_layout(src.layout());
auto src_place = src.place();
auto src_ptr = src.data<void>();
auto dst_ptr = dst->mutable_data(dst_place, src.type());
auto size = src.numel() * SizeOfType(src.type());
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
auto dst_cpu_place = boost::get<platform::CPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_cpu_place = boost::get<platform::CPUPlace>(src_place);
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
void TensorCopy(const Tensor& src, const platform::Place& dst_place,
Tensor* dst) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
const platform::DeviceContext* dev_ctx;
if (platform::is_gpu_place(src.place())) {
dev_ctx = pool.Get(src.place());
} else {
dev_ctx = pool.Get(dst_place);
}
TensorCopy(src, dst_place, *dev_ctx, dst);
}
template <typename Predicate, typename DevCtx> template <typename Predicate, typename DevCtx>
struct AnyDTypeVisitor { struct AnyDTypeVisitor {
Predicate predicate_; Predicate predicate_;
...@@ -69,7 +139,7 @@ struct AnyVisitor : public boost::static_visitor<bool> { ...@@ -69,7 +139,7 @@ struct AnyVisitor : public boost::static_visitor<bool> {
tmp.mutable_data<bool>(cpu); tmp.mutable_data<bool>(cpu);
auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu); auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu);
gpuctx->Wait(); gpuctx->Wait();
Copy(out, cpu, *gpuctx, &tmp); TensorCopy(out, cpu, *gpuctx, &tmp);
gpuctx->Wait(); gpuctx->Wait();
return GetResult(tmp, cpu); return GetResult(tmp, cpu);
} }
...@@ -87,7 +157,7 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) { ...@@ -87,7 +157,7 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
return platform::VisitPlace(place, visitor); return platform::VisitPlace(place, visitor);
} }
struct HasNANPredicate { struct ContainsNANPredicate {
template <typename T> template <typename T>
auto operator()(const T& eigen_vec) const auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isnan()) { -> decltype(std::declval<T>().isnan()) {
...@@ -96,12 +166,12 @@ struct HasNANPredicate { ...@@ -96,12 +166,12 @@ struct HasNANPredicate {
} }
}; };
bool HasNAN(const framework::Tensor& tensor) { bool TensorContainsNAN(const framework::Tensor& tensor) {
HasNANPredicate predicate; ContainsNANPredicate predicate;
return Any(tensor, predicate); return Any(tensor, predicate);
} }
struct HasInfPredicate { struct ContainsInfPredicate {
template <typename T> template <typename T>
auto operator()(const T& eigen_vec) const auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isinf()) { -> decltype(std::declval<T>().isinf()) {
...@@ -110,10 +180,124 @@ struct HasInfPredicate { ...@@ -110,10 +180,124 @@ struct HasInfPredicate {
} }
}; };
bool HasInf(const framework::Tensor& tensor) { bool TensorContainsInf(const framework::Tensor& tensor) {
HasInfPredicate predicate; ContainsInfPredicate predicate;
return Any(tensor, predicate); return Any(tensor, predicate);
} }
void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx) {
// TODO(typhoonzero): serialize to ostream
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char*>(&version), sizeof(version));
}
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
std::copy(dims.begin(), dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char*>(&size), sizeof(size));
auto out = desc.SerializeAsString();
os.write(out.data(), size);
}
{ // the 3rd field, tensor data
uint64_t size = tensor.memory_size();
auto* data_ptr = tensor.data<void>();
PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
"Index overflow when writing tensor");
if (platform::is_gpu_place(tensor.place())) {
#ifdef PADDLE_WITH_CUDA
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
std::unique_ptr<char[]> buf(new char[kBufSize]);
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(),
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(data), size_to_write,
gpu_dev_ctx.stream());
gpu_dev_ctx.Wait();
os.write(buf.get(), size_to_write);
data += size_to_write;
size -= size_to_write;
}
#else
PADDLE_THROW("Unexpected branch");
#endif
} else {
os.write(static_cast<const char*>(data_ptr),
static_cast<std::streamsize>(size));
}
}
}
struct DeserializedDataFunctor {
DeserializedDataFunctor(void** buf, Tensor* tensor,
const platform::Place& place)
: buf_(buf), tensor_(tensor), place_(place) {}
template <typename T>
void operator()() {
*buf_ = tensor_->mutable_data<T>(place_);
}
void** buf_;
Tensor* tensor_;
platform::Place place_;
};
void TensorFromStream(std::istream& is, Tensor* tensor,
const platform::DeviceContext& dev_ctx) {
uint32_t version;
is.read(reinterpret_cast<char*>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
proto::VarType::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
is.read(reinterpret_cast<char*>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char*>(buf.get()), size);
PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
"Cannot parse tensor desc");
}
{ // read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
tensor->Resize(framework::make_ddim(dims));
void* buf;
auto ctx = platform::CPUDeviceContext();
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
Tensor cpu_tensor;
cpu_tensor.Resize(framework::make_ddim(dims));
framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
is.read(static_cast<char*>(buf), cpu_tensor.memory_size());
auto dst_place = dev_ctx.GetPlace();
framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
#else
PADDLE_THROW("Unexpected branch");
#endif
} else {
framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
is.read(static_cast<char*>(buf), tensor->memory_size());
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace framework {
template <typename Predicate, typename DevCtx>
struct AnyDTypeVisitor {
Predicate predicate_;
const Tensor& tensor_;
const DevCtx& ctx_;
Tensor* out_;
AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx,
Tensor* out)
: predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
template <typename T>
void operator()() const {
auto t = EigenVector<T>::Flatten(tensor_);
auto o = EigenScalar<bool>::From(*out_);
// return any of predicate_(t) is true.
o.device(*ctx_.eigen_device()) = predicate_(t).any();
}
};
template <typename Predicate, typename DevCtx>
inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
const DevCtx& ctx, framework::Tensor* out) {
VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>(
predicate, tensor, ctx, out));
}
template <typename Predicate>
struct AnyVisitor : public boost::static_visitor<bool> {
const framework::Tensor& tensor_;
Predicate predicate_;
AnyVisitor(const framework::Tensor& tensor, Predicate predicate)
: tensor_(tensor), predicate_(std::move(predicate)) {}
template <typename Place>
bool operator()(const Place& place) const {
framework::Tensor out;
out.Resize({1});
out.mutable_data<bool>(place);
auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
AnyImpl(predicate_, tensor_, *ctx, &out);
return this->GetResult(out, place);
}
bool GetResult(const framework::Tensor& out,
const platform::CUDAPlace& gpu) const {
platform::CPUPlace cpu;
framework::Tensor tmp;
tmp.Resize({1});
tmp.mutable_data<bool>(cpu);
auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu);
gpuctx->Wait();
Copy(out, cpu, *gpuctx, &tmp);
gpuctx->Wait();
return GetResult(tmp, cpu);
}
bool GetResult(const framework::Tensor& out,
const platform::CPUPlace& cpu) const {
return *out.data<bool>();
}
};
template <typename Predicate>
inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
AnyVisitor<Predicate> visitor(tensor, predicate);
auto place = tensor.place();
return platform::VisitPlace(place, visitor);
}
struct HasNANPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isnan()) {
// Cast eigen_vector to vector of bool. true if is inf.
return eigen_vec.isnan();
}
};
bool HasNAN(const framework::Tensor& tensor) {
HasNANPredicate predicate;
return Any(tensor, predicate);
}
struct HasInfPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isinf()) {
// Cast eigen_vector to vector of bool. true if is inf.
return eigen_vec.isinf();
}
};
bool HasInf(const framework::Tensor& tensor) {
HasInfPredicate predicate;
return Any(tensor, predicate);
}
} // namespace framework
} // namespace paddle
tensor_util.cc
\ No newline at end of file
...@@ -22,105 +22,37 @@ limitations under the License. */ ...@@ -22,105 +22,37 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
/** void TensorCopy(const Tensor& src, const platform::Place& dst_place,
* @brief Copy the content of external tensor to a new place. const platform::DeviceContext& ctx, Tensor* dst);
* void TensorCopy(const Tensor& src, const platform::Place& dst_place,
* @param[in] src The external tensor. Tensor* dst);
* @param[in] dst_place The dst place.
* @param[in] ctx The device context contains device resources.
*
* @note Copy supports CPU <-> GPU, GPU <-> GPU.
*/
inline void Copy(const Tensor& src, const platform::Place& dst_place,
const platform::DeviceContext& ctx, Tensor* dst) {
VLOG(3) << "Copy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
src.check_memory_size();
dst->Resize(src.dims()); template <typename T>
dst->set_layout(src.layout()); void TensorFromVector(const std::vector<T>& src,
auto src_place = src.place(); const platform::DeviceContext& ctx, Tensor* dst);
auto src_ptr = src.data<void>(); template <typename T>
void TensorFromVector(const std::vector<T>& src, Tensor* dst);
auto dst_ptr = dst->mutable_data(dst_place, src.type()); template <typename T>
void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
std::vector<T>* dst);
template <typename T>
void TesnorToVector(const Tensor& src, std::vector<T>* dst);
auto size = src.numel() * SizeOfType(src.type()); bool TensorContainsNAN(const framework::Tensor& tensor);
bool TensorContainsInf(const framework::Tensor& tensor);
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { void TensorToStream(std::ostream& os, const Tensor& tensor,
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, const platform::DeviceContext& dev_ctx);
boost::get<platform::CPUPlace>(src_place), src_ptr, size); void TensorFromStream(std::istream& is, Tensor* tensor,
} const platform::DeviceContext& dev_ctx);
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
auto dst_cpu_place = boost::get<platform::CPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_cpu_place = boost::get<platform::CPUPlace>(src_place);
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
/** //
* @brief Wrapper on // The implementation of template functions.
* Copy(const Tensor& src, const platform::Place& dst_place, //
* const platform::DeviceContext& ctx, Tensor* dst);
*
* @param[in] src The external tensor.
* @param[in] dst_place The dst place.
*
* @note Copy supports CPU <-> GPU, GPU <-> GPU.
*/
inline void Copy(const Tensor& src, const platform::Place& dst_place,
Tensor* dst) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
const platform::DeviceContext* dev_ctx;
if (platform::is_gpu_place(src.place())) {
dev_ctx = pool.Get(src.place());
} else {
dev_ctx = pool.Get(dst_place);
}
Copy(src, dst_place, *dev_ctx, dst);
}
/**
* @brief Copy the content of an external vector to a tensor.
*
* @param[in] src The external tensor.
* @param[in] ctx The device context contains device resources.
*
* * @note CopyFromVector will resize dst to an 1D tensor with the same
* size as src.
*/
template <typename T> template <typename T>
inline void CopyFromVector(const std::vector<T>& src, void TensorFromVector(const std::vector<T>& src,
const platform::DeviceContext& ctx, Tensor* dst) { const platform::DeviceContext& ctx, Tensor* dst) {
auto dst_place = ctx.GetPlace(); auto dst_place = ctx.GetPlace();
auto src_ptr = static_cast<const void*>(src.data()); auto src_ptr = static_cast<const void*>(src.data());
...@@ -143,11 +75,8 @@ inline void CopyFromVector(const std::vector<T>& src, ...@@ -143,11 +75,8 @@ inline void CopyFromVector(const std::vector<T>& src,
#endif #endif
} }
/**
* @brief CopyFromVector CPU vector -> CPU Tensor
*/
template <typename T> template <typename T>
inline void CopyFromVector(const std::vector<T>& src, Tensor* dst) { void TensorFromVector(const std::vector<T>& src, Tensor* dst) {
platform::CPUPlace dst_place = platform::CPUPlace(); platform::CPUPlace dst_place = platform::CPUPlace();
auto src_ptr = static_cast<const void*>(src.data()); auto src_ptr = static_cast<const void*>(src.data());
platform::CPUPlace src_place; platform::CPUPlace src_place;
...@@ -158,17 +87,8 @@ inline void CopyFromVector(const std::vector<T>& src, Tensor* dst) { ...@@ -158,17 +87,8 @@ inline void CopyFromVector(const std::vector<T>& src, Tensor* dst) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} }
/**
* @brief Copy the content of a tensor to a vector
*
* @param[in] src The external tensor.
* @param[in] ctx The device context contains device resources.
*
* * @note CopyFromVector assumes that the tensor has been resized
* before invoking.
*/
template <typename T> template <typename T>
inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx, void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
std::vector<T>* dst) { std::vector<T>* dst) {
auto src_ptr = static_cast<const void*>(src.data<T>()); auto src_ptr = static_cast<const void*>(src.data<T>());
auto size = src.numel() * sizeof(T); auto size = src.numel() * sizeof(T);
...@@ -191,11 +111,8 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx, ...@@ -191,11 +111,8 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx,
#endif #endif
} }
/**
* @brief CopyToVector CPUTensor <-> CPU Vector
*/
template <typename T> template <typename T>
inline void CopyToVector(const Tensor& src, std::vector<T>* dst) { void TensorToVector(const Tensor& src, std::vector<T>* dst) {
auto src_ptr = static_cast<const void*>(src.data<T>()); auto src_ptr = static_cast<const void*>(src.data<T>());
auto size = src.numel() * sizeof(T); auto size = src.numel() * sizeof(T);
...@@ -209,125 +126,5 @@ inline void CopyToVector(const Tensor& src, std::vector<T>* dst) { ...@@ -209,125 +126,5 @@ inline void CopyToVector(const Tensor& src, std::vector<T>* dst) {
src_ptr, size); src_ptr, size);
} }
// Returns true if a tensor contains NAN, i.e., Not A Number.
bool HasNAN(const framework::Tensor& tensor);
// Returns true if a tensor contains Inf, i.e., Infinity.
bool HasInf(const framework::Tensor& tensor);
inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx) {
// TODO(typhoonzero): serialize to ostream
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char*>(&version), sizeof(version));
}
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
std::copy(dims.begin(), dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char*>(&size), sizeof(size));
auto out = desc.SerializeAsString();
os.write(out.data(), size);
}
{ // the 3rd field, tensor data
uint64_t size = tensor.memory_size();
auto* data_ptr = tensor.data<void>();
PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
"Index overflow when writing tensor");
if (platform::is_gpu_place(tensor.place())) {
#ifdef PADDLE_WITH_CUDA
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
std::unique_ptr<char[]> buf(new char[kBufSize]);
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(),
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(data), size_to_write,
gpu_dev_ctx.stream());
gpu_dev_ctx.Wait();
os.write(buf.get(), size_to_write);
data += size_to_write;
size -= size_to_write;
}
#else
PADDLE_THROW("Unexpected branch");
#endif
} else {
os.write(static_cast<const char*>(data_ptr),
static_cast<std::streamsize>(size));
}
}
}
struct DeserializedDataFunctor {
DeserializedDataFunctor(void** buf, Tensor* tensor,
const platform::Place& place)
: buf_(buf), tensor_(tensor), place_(place) {}
template <typename T>
void operator()() {
*buf_ = tensor_->mutable_data<T>(place_);
}
void** buf_;
Tensor* tensor_;
platform::Place place_;
};
inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
const platform::DeviceContext& dev_ctx) {
uint32_t version;
is.read(reinterpret_cast<char*>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
proto::VarType::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
is.read(reinterpret_cast<char*>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char*>(buf.get()), size);
PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
"Cannot parse tensor desc");
}
{ // read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
tensor->Resize(framework::make_ddim(dims));
void* buf;
auto ctx = platform::CPUDeviceContext();
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
Tensor cpu_tensor;
cpu_tensor.Resize(framework::make_ddim(dims));
framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
is.read(static_cast<char*>(buf), cpu_tensor.memory_size());
auto dst_place = dev_ctx.GetPlace();
framework::Copy(cpu_tensor, dst_place, dev_ctx, tensor);
#else
PADDLE_THROW("Unexpected branch");
#endif
} else {
framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
is.read(static_cast<char*>(buf), tensor->memory_size());
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
TEST(Copy, Tensor) { TEST(TensorCopy, Tensor) {
Tensor src_tensor; Tensor src_tensor;
Tensor dst_tensor; Tensor dst_tensor;
platform::CPUDeviceContext cpu_ctx((platform::CPUPlace())); platform::CPUDeviceContext cpu_ctx((platform::CPUPlace()));
...@@ -33,7 +33,7 @@ TEST(Copy, Tensor) { ...@@ -33,7 +33,7 @@ TEST(Copy, Tensor) {
src_tensor.set_layout(DataLayout::kAnyLayout); src_tensor.set_layout(DataLayout::kAnyLayout);
auto cpu_place = new platform::CPUPlace(); auto cpu_place = new platform::CPUPlace();
Copy(src_tensor, *cpu_place, &dst_tensor); TensorCopy(src_tensor, *cpu_place, &dst_tensor);
const int* dst_ptr = dst_tensor.data<int>(); const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr); ASSERT_NE(src_ptr, dst_ptr);
...@@ -44,7 +44,7 @@ TEST(Copy, Tensor) { ...@@ -44,7 +44,7 @@ TEST(Copy, Tensor) {
EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout()); EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
Tensor slice_tensor = src_tensor.Slice(1, 2); Tensor slice_tensor = src_tensor.Slice(1, 2);
Copy(slice_tensor, *cpu_place, &dst_tensor); TensorCopy(slice_tensor, *cpu_place, &dst_tensor);
const int* slice_ptr = slice_tensor.data<int>(); const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>(); dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr); ASSERT_NE(dst_ptr, slice_ptr);
...@@ -68,11 +68,11 @@ TEST(Copy, Tensor) { ...@@ -68,11 +68,11 @@ TEST(Copy, Tensor) {
// CPU Tensor to GPU Tensor // CPU Tensor to GPU Tensor
auto gpu_place = new platform::CUDAPlace(0); auto gpu_place = new platform::CUDAPlace(0);
platform::CUDADeviceContext gpu_ctx(*gpu_place); platform::CUDADeviceContext gpu_ctx(*gpu_place);
Copy(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor); TensorCopy(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
// GPU Tensor to CPU Tensor // GPU Tensor to CPU Tensor
auto cpu_place = new platform::CPUPlace(); auto cpu_place = new platform::CPUPlace();
Copy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); TensorCopy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor);
// Sync before Compare Tensors // Sync before Compare Tensors
gpu_ctx.Wait(); gpu_ctx.Wait();
...@@ -85,10 +85,10 @@ TEST(Copy, Tensor) { ...@@ -85,10 +85,10 @@ TEST(Copy, Tensor) {
Tensor slice_tensor = src_tensor.Slice(1, 2); Tensor slice_tensor = src_tensor.Slice(1, 2);
// CPU Slice Tensor to GPU Tensor // CPU Slice Tensor to GPU Tensor
Copy(slice_tensor, *gpu_place, gpu_ctx, &gpu_tensor); TensorCopy(slice_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
// GPU Tensor to CPU Tensor // GPU Tensor to CPU Tensor
Copy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); TensorCopy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor);
// Sync before Compare Slice Tensors // Sync before Compare Slice Tensors
gpu_ctx.Wait(); gpu_ctx.Wait();
...@@ -104,7 +104,7 @@ TEST(Copy, Tensor) { ...@@ -104,7 +104,7 @@ TEST(Copy, Tensor) {
#endif #endif
} }
TEST(CopyFromVector, Tensor) { TEST(TensorFromVector, Tensor) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
{ {
...@@ -114,7 +114,7 @@ TEST(CopyFromVector, Tensor) { ...@@ -114,7 +114,7 @@ TEST(CopyFromVector, Tensor) {
// Copy to CPU Tensor // Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3})); cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace(); auto cpu_place = new paddle::platform::CPUPlace();
CopyFromVector<int>(src_vec, &cpu_tensor); TensorFromVector<int>(src_vec, &cpu_tensor);
// Compare Tensors // Compare Tensors
const int* cpu_ptr = cpu_tensor.data<int>(); const int* cpu_ptr = cpu_tensor.data<int>();
...@@ -126,7 +126,7 @@ TEST(CopyFromVector, Tensor) { ...@@ -126,7 +126,7 @@ TEST(CopyFromVector, Tensor) {
src_vec.erase(src_vec.begin(), src_vec.begin() + 5); src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2})); cpu_tensor.Resize(make_ddim({2, 2}));
CopyFromVector<int>(src_vec, &cpu_tensor); TensorFromVector<int>(src_vec, &cpu_tensor);
cpu_ptr = cpu_tensor.data<int>(); cpu_ptr = cpu_tensor.data<int>();
src_ptr = src_vec.data(); src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr); ASSERT_NE(src_ptr, cpu_ptr);
...@@ -148,15 +148,15 @@ TEST(CopyFromVector, Tensor) { ...@@ -148,15 +148,15 @@ TEST(CopyFromVector, Tensor) {
cpu_tensor.Resize(make_ddim({3, 3})); cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace(); auto cpu_place = new paddle::platform::CPUPlace();
CPUDeviceContext cpu_ctx(*cpu_place); CPUDeviceContext cpu_ctx(*cpu_place);
CopyFromVector<int>(src_vec, cpu_ctx, &cpu_tensor); TensorFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
// Copy to GPUTensor // Copy to GPUTensor
gpu_tensor.Resize(make_ddim({3, 3})); gpu_tensor.Resize(make_ddim({3, 3}));
auto gpu_place = new paddle::platform::CUDAPlace(); auto gpu_place = new paddle::platform::CUDAPlace();
CUDADeviceContext gpu_ctx(*gpu_place); CUDADeviceContext gpu_ctx(*gpu_place);
CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor); TensorFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);
// Copy from GPU to CPU tensor for comparison // Copy from GPU to CPU tensor for comparison
Copy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); TensorCopy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor);
// Sync before Compare Tensors // Sync before Compare Tensors
gpu_ctx.Wait(); gpu_ctx.Wait();
...@@ -173,10 +173,10 @@ TEST(CopyFromVector, Tensor) { ...@@ -173,10 +173,10 @@ TEST(CopyFromVector, Tensor) {
src_vec.erase(src_vec.begin(), src_vec.begin() + 5); src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2})); cpu_tensor.Resize(make_ddim({2, 2}));
CopyFromVector<int>(src_vec, cpu_ctx, &cpu_tensor); TensorFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
gpu_tensor.Resize(make_ddim({2, 2})); gpu_tensor.Resize(make_ddim({2, 2}));
CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor); TensorFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);
Copy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor); TensorCopy(gpu_tensor, *cpu_place, gpu_ctx, &dst_tensor);
// Sync before Compare Tensors // Sync before Compare Tensors
gpu_ctx.Wait(); gpu_ctx.Wait();
...@@ -196,7 +196,7 @@ TEST(CopyFromVector, Tensor) { ...@@ -196,7 +196,7 @@ TEST(CopyFromVector, Tensor) {
#endif #endif
} }
TEST(CopyToVector, Tensor) { TEST(TensorToVector, Tensor) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
{ {
...@@ -208,7 +208,7 @@ TEST(CopyToVector, Tensor) { ...@@ -208,7 +208,7 @@ TEST(CopyToVector, Tensor) {
CPUPlace place; CPUPlace place;
std::vector<int> dst; std::vector<int> dst;
CopyToVector<int>(src, &dst); TensorToVector<int>(src, &dst);
for (int i = 0; i < 3 * 3; ++i) { for (int i = 0; i < 3 * 3; ++i) {
EXPECT_EQ(src_ptr[i], dst[i]); EXPECT_EQ(src_ptr[i], dst[i]);
...@@ -220,10 +220,10 @@ TEST(CopyToVector, Tensor) { ...@@ -220,10 +220,10 @@ TEST(CopyToVector, Tensor) {
Tensor gpu_tensor; Tensor gpu_tensor;
CUDAPlace place; CUDAPlace place;
CUDADeviceContext gpu_ctx(place); CUDADeviceContext gpu_ctx(place);
CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor); TensorFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);
std::vector<int> dst; std::vector<int> dst;
CopyToVector<int>(gpu_tensor, gpu_ctx, &dst); TensorToVector<int>(gpu_tensor, gpu_ctx, &dst);
for (int i = 0; i < 3 * 3; ++i) { for (int i = 0; i < 3 * 3; ++i) {
EXPECT_EQ(src_vec[i], dst[i]); EXPECT_EQ(src_vec[i], dst[i]);
...@@ -232,7 +232,7 @@ TEST(CopyToVector, Tensor) { ...@@ -232,7 +232,7 @@ TEST(CopyToVector, Tensor) {
#endif #endif
} }
TEST(HasNAN, CPU) { TEST(TensorContainsNAN, CPU) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
Tensor src; Tensor src;
...@@ -240,11 +240,12 @@ TEST(HasNAN, CPU) { ...@@ -240,11 +240,12 @@ TEST(HasNAN, CPU) {
buf[0] = 0.0; buf[0] = 0.0;
buf[1] = NAN; buf[1] = NAN;
buf[2] = 0.0; buf[2] = 0.0;
ASSERT_TRUE(TensorContainsNAN(src));
ASSERT_TRUE(HasNAN(src)); buf[1] = 0.0;
ASSERT_FALSE(TensorContainsNAN(src));
} }
TEST(HasInf, CPU) { TEST(TensorContainsInf, CPU) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
Tensor src; Tensor src;
...@@ -252,10 +253,12 @@ TEST(HasInf, CPU) { ...@@ -252,10 +253,12 @@ TEST(HasInf, CPU) {
buf[0] = 1.0; buf[0] = 1.0;
buf[1] = INFINITY; buf[1] = INFINITY;
buf[2] = 0.0; buf[2] = 0.0;
ASSERT_TRUE(HasInf(src)); ASSERT_TRUE(TensorContainsInf(src));
buf[1] = 1.0;
ASSERT_FALSE(TensorContainsInf(src));
} }
TEST(Tensor, SerializeAndDeserialize) { TEST(Tensor, FromAndToStream) {
framework::Tensor src_tensor; framework::Tensor src_tensor;
int array[6] = {1, 2, 3, 4, 5, 6}; int array[6] = {1, 2, 3, 4, 5, 6};
src_tensor.Resize({2, 3}); src_tensor.Resize({2, 3});
...@@ -268,10 +271,10 @@ TEST(Tensor, SerializeAndDeserialize) { ...@@ -268,10 +271,10 @@ TEST(Tensor, SerializeAndDeserialize) {
auto place = new platform::CPUPlace(); auto place = new platform::CPUPlace();
platform::CPUDeviceContext cpu_ctx(*place); platform::CPUDeviceContext cpu_ctx(*place);
std::ostringstream oss; std::ostringstream oss;
SerializeToStream(oss, src_tensor, cpu_ctx); TensorToStream(oss, src_tensor, cpu_ctx);
std::istringstream iss(oss.str()); std::istringstream iss(oss.str());
DeserializeFromStream(iss, &dst_tensor, cpu_ctx); TensorFromStream(iss, &dst_tensor, cpu_ctx);
int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace()); int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
ASSERT_EQ(dst_ptr[i], array[i]); ASSERT_EQ(dst_ptr[i], array[i]);
...@@ -288,13 +291,13 @@ TEST(Tensor, SerializeAndDeserialize) { ...@@ -288,13 +291,13 @@ TEST(Tensor, SerializeAndDeserialize) {
auto gpu_place = new platform::CUDAPlace(); auto gpu_place = new platform::CUDAPlace();
platform::CUDADeviceContext gpu_ctx(*gpu_place); platform::CUDADeviceContext gpu_ctx(*gpu_place);
Copy(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor); TensorCopy(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
std::ostringstream oss; std::ostringstream oss;
SerializeToStream(oss, gpu_tensor, gpu_ctx); TensorToStream(oss, gpu_tensor, gpu_ctx);
std::istringstream iss(oss.str()); std::istringstream iss(oss.str());
DeserializeFromStream(iss, &dst_tensor, gpu_ctx); TensorFromStream(iss, &dst_tensor, gpu_ctx);
int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace()); int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
......
...@@ -31,7 +31,7 @@ static __global__ void FillInf(float* buf) { ...@@ -31,7 +31,7 @@ static __global__ void FillInf(float* buf) {
buf[2] = 0.5; buf[2] = 0.5;
} }
TEST(HasNAN, GPU) { TEST(TensorContainsNAN, GPU) {
Tensor tensor; Tensor tensor;
platform::CUDAPlace gpu(0); platform::CUDAPlace gpu(0);
auto& pool = platform::DeviceContextPool::Instance(); auto& pool = platform::DeviceContextPool::Instance();
...@@ -39,10 +39,10 @@ TEST(HasNAN, GPU) { ...@@ -39,10 +39,10 @@ TEST(HasNAN, GPU) {
float* buf = tensor.mutable_data<float>({3}, gpu); float* buf = tensor.mutable_data<float>({3}, gpu);
FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait(); cuda_ctx->Wait();
ASSERT_TRUE(HasNAN(tensor)); ASSERT_TRUE(TensorContainsNAN(tensor));
} }
TEST(HasInf, GPU) { TEST(TensorContainsInf, GPU) {
Tensor tensor; Tensor tensor;
platform::CUDAPlace gpu(0); platform::CUDAPlace gpu(0);
auto& pool = platform::DeviceContextPool::Instance(); auto& pool = platform::DeviceContextPool::Instance();
...@@ -50,7 +50,7 @@ TEST(HasInf, GPU) { ...@@ -50,7 +50,7 @@ TEST(HasInf, GPU) {
float* buf = tensor.mutable_data<float>({3}, gpu); float* buf = tensor.mutable_data<float>({3}, gpu);
FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait(); cuda_ctx->Wait();
ASSERT_TRUE(HasInf(tensor)); ASSERT_TRUE(TensorContainsInf(tensor));
} }
} // namespace framework } // namespace framework
......
...@@ -64,7 +64,6 @@ class ThreadPool { ...@@ -64,7 +64,6 @@ class ThreadPool {
Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> { Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
try { try {
fn(); fn();
return nullptr;
} catch (platform::EnforceNotMet ex) { } catch (platform::EnforceNotMet ex) {
return std::unique_ptr<platform::EnforceNotMet>( return std::unique_ptr<platform::EnforceNotMet>(
new platform::EnforceNotMet(ex)); new platform::EnforceNotMet(ex));
...@@ -73,6 +72,7 @@ class ThreadPool { ...@@ -73,6 +72,7 @@ class ThreadPool {
<< "Unexpected exception is catched in thread pool. All " << "Unexpected exception is catched in thread pool. All "
"throwable exception in Fluid should be an EnforceNotMet."; "throwable exception in Fluid should be an EnforceNotMet.";
} }
return nullptr;
}); });
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future(); std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
tasks_.push(std::move(task)); tasks_.push(std::move(task));
......
...@@ -176,6 +176,20 @@ op_library(pool_op SRCS pool_op.cc DEPS pooling) ...@@ -176,6 +176,20 @@ op_library(pool_op SRCS pool_op.cc DEPS pooling)
op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col) op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col)
endif() endif()
cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry)
op_library(fill_constant_batch_size_like_op
SRCS fill_constant_batch_size_like_op.cc fill_constant_batch_size_like_op.cu.cc
DEPS batch_size_like)
op_library(uniform_random_batch_size_like_op
SRCS uniform_random_batch_size_like_op.cc
DEPS batch_size_like uniform_random_op)
op_library(gaussian_random_batch_size_like_op
SRCS gaussian_random_batch_size_like_op.cc
DEPS batch_size_like gaussian_random_op)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions # FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor) op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor)
......
...@@ -42,7 +42,7 @@ class ArrayOp : public framework::OperatorBase { ...@@ -42,7 +42,7 @@ class ArrayOp : public framework::OperatorBase {
if (platform::is_gpu_place(i_tensor.place())) { if (platform::is_gpu_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU // FIXME: Avoid copy from GPU to CPU
framework::Tensor t; framework::Tensor t;
framework::Copy(i_tensor, platform::CPUPlace(), dev_ctx, &t); framework::TensorCopy(i_tensor, platform::CPUPlace(), dev_ctx, &t);
dev_ctx.Wait(); dev_ctx.Wait();
offset = static_cast<size_t>(*t.data<int64_t>()); offset = static_cast<size_t>(*t.data<int64_t>());
} else { } else {
......
...@@ -112,7 +112,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { ...@@ -112,7 +112,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
framework::Copy(x[x_idx].Slice(start_offset, end_offset), place, framework::TensorCopy(x[x_idx].Slice(start_offset, end_offset), place,
dev_ctx, &slice); dev_ctx, &slice);
out_offset += len; out_offset += len;
} }
......
...@@ -45,7 +45,7 @@ class AssignFunctor { ...@@ -45,7 +45,7 @@ class AssignFunctor {
out_rows.set_height(rows.height()); out_rows.set_height(rows.height());
auto &t = rows.value(); auto &t = rows.value();
auto *m = out_rows.mutable_value(); auto *m = out_rows.mutable_value();
framework::Copy(t, t.place(), dev_ctx_, m); framework::TensorCopy(t, t.place(), dev_ctx_, m);
} }
template <typename T> template <typename T>
...@@ -57,7 +57,7 @@ class AssignFunctor { ...@@ -57,7 +57,7 @@ class AssignFunctor {
void copy_tensor(const framework::LoDTensor &lod_tensor, void copy_tensor(const framework::LoDTensor &lod_tensor,
framework::LoDTensor *out) const { framework::LoDTensor *out) const {
auto &out_tensor = *out; auto &out_tensor = *out;
Copy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor); TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor);
out_tensor.set_lod(lod_tensor.lod()); out_tensor.set_lod(lod_tensor.lod());
} }
......
...@@ -41,7 +41,7 @@ class AssignValueKernel : public framework::OpKernel<T> { ...@@ -41,7 +41,7 @@ class AssignValueKernel : public framework::OpKernel<T> {
break; break;
} }
auto values = ctx.Attr<std::vector<T>>(value_name); auto values = ctx.Attr<std::vector<T>>(value_name);
framework::CopyFromVector(values, ctx.device_context(), out); framework::TensorFromVector(values, ctx.device_context(), out);
out->Resize(framework::make_ddim(shape)); out->Resize(framework::make_ddim(shape));
} }
}; };
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/batch_size_like.h"
namespace paddle {
namespace operators {
void BatchSizeLikeOp::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of %s should not be null.", Type());
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of %s should not be null.",
Type());
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(shape.size(), 0);
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto output_dim = framework::make_ddim(shape_int64);
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
PADDLE_ENFORCE_GE(input_dim_idx, 0);
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
PADDLE_ENFORCE_GE(output_dim_idx, 0);
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
ctx->SetOutputDim("Out", output_dim);
}
BatchSizeLikeOpMaker::BatchSizeLikeOpMaker(OpProto *proto,
OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(Tensor) Tensor "
"whose input_dim_idx'th dimension specifies the batch_size");
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<int>("input_dim_idx",
"(int, default 0) The index of input's batch size dimension")
.SetDefault(0);
AddAttr<int>("output_dim_idx",
"(int, default 0) The index of output's batch size dimension")
.SetDefault(0);
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
class BatchSizeLikeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override;
};
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker);
};
} // namespace operators
} // namespace paddle
...@@ -232,12 +232,12 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor( ...@@ -232,12 +232,12 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
id_tensor->set_lod(lod); id_tensor->set_lod(lod);
id_tensor->Resize({static_cast<int64_t>(id_data.size())}); id_tensor->Resize({static_cast<int64_t>(id_data.size())});
id_tensor->mutable_data<int64_t>(paddle::platform::CPUPlace()); id_tensor->mutable_data<int64_t>(paddle::platform::CPUPlace());
framework::CopyFromVector<int64_t>(id_data, cpu_ctx, id_tensor); framework::TensorFromVector<int64_t>(id_data, cpu_ctx, id_tensor);
score_tensor->set_lod(lod); score_tensor->set_lod(lod);
score_tensor->Resize({static_cast<int64_t>(score_data.size())}); score_tensor->Resize({static_cast<int64_t>(score_data.size())});
score_tensor->mutable_data<T>(paddle::platform::CPUPlace()); score_tensor->mutable_data<T>(paddle::platform::CPUPlace());
framework::CopyFromVector<T>(score_data, cpu_ctx, score_tensor); framework::TensorFromVector<T>(score_data, cpu_ctx, score_tensor);
} }
template <typename T> template <typename T>
......
...@@ -67,7 +67,6 @@ class CompareOpKernel ...@@ -67,7 +67,6 @@ class CompareOpKernel
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y"); auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out"); auto* z = context.Output<Tensor>("Out");
z->mutable_data<T>(context.GetPlace());
int axis = context.Attr<int>("axis"); int axis = context.Attr<int>("axis");
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis, ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
Functor(), z); Functor(), z);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -98,15 +98,15 @@ class DetectionOutputKernel : public framework::OpKernel<T> { ...@@ -98,15 +98,15 @@ class DetectionOutputKernel : public framework::OpKernel<T> {
T* conf_data = conf_tensor.data<T>(); T* conf_data = conf_tensor.data<T>();
if (platform::is_gpu_place(context.GetPlace())) { if (platform::is_gpu_place(context.GetPlace())) {
loc_cpu.mutable_data<T>(loc_tensor.dims(), platform::CPUPlace()); loc_cpu.mutable_data<T>(loc_tensor.dims(), platform::CPUPlace());
framework::Copy(loc_tensor, platform::CPUPlace(), framework::TensorCopy(loc_tensor, platform::CPUPlace(),
context.device_context(), &loc_cpu); context.device_context(), &loc_cpu);
loc_data = loc_cpu.data<T>(); loc_data = loc_cpu.data<T>();
conf_cpu.mutable_data<T>(conf_tensor.dims(), platform::CPUPlace()); conf_cpu.mutable_data<T>(conf_tensor.dims(), platform::CPUPlace());
framework::Copy(conf_tensor, platform::CPUPlace(), framework::TensorCopy(conf_tensor, platform::CPUPlace(),
context.device_context(), &conf_cpu); context.device_context(), &conf_cpu);
conf_data = conf_cpu.data<T>(); conf_data = conf_cpu.data<T>();
priorbox_cpu.mutable_data<T>(in_priorbox->dims(), platform::CPUPlace()); priorbox_cpu.mutable_data<T>(in_priorbox->dims(), platform::CPUPlace());
framework::Copy(*in_priorbox, platform::CPUPlace(), framework::TensorCopy(*in_priorbox, platform::CPUPlace(),
context.device_context(), &priorbox_cpu); context.device_context(), &priorbox_cpu);
priorbox_data = priorbox_cpu.data<T>(); priorbox_data = priorbox_cpu.data<T>();
} }
...@@ -158,8 +158,8 @@ class DetectionOutputKernel : public framework::OpKernel<T> { ...@@ -158,8 +158,8 @@ class DetectionOutputKernel : public framework::OpKernel<T> {
batch_size, all_indices, all_decoded_bboxes, batch_size, all_indices, all_decoded_bboxes,
out_data); out_data);
if (platform::is_gpu_place(context.GetPlace())) { if (platform::is_gpu_place(context.GetPlace())) {
framework::Copy(out_cpu, platform::CUDAPlace(), context.device_context(), framework::TensorCopy(out_cpu, platform::CUDAPlace(),
out); context.device_context(), out);
} }
} }
}; };
......
...@@ -126,7 +126,8 @@ class ExpandGradKernel : public framework::OpKernel<T> { ...@@ -126,7 +126,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out")); auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X")); auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace()); out0->mutable_data<T>(context.GetPlace());
framework::Copy(*in0, context.GetPlace(), context.device_context(), out0); framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
out0);
} else { } else {
switch (dims) { switch (dims) {
REP_EXPAND_GRAD_TEMPLATE(72) REP_EXPAND_GRAD_TEMPLATE(72)
......
...@@ -57,7 +57,7 @@ class FeedOp : public framework::OperatorBase { ...@@ -57,7 +57,7 @@ class FeedOp : public framework::OperatorBase {
if (platform::is_same_place(feed_item.place(), place)) { if (platform::is_same_place(feed_item.place(), place)) {
out_item->ShareDataWith(feed_item); out_item->ShareDataWith(feed_item);
} else { } else {
framework::Copy(feed_item, place, dev_ctx, out_item); framework::TensorCopy(feed_item, place, dev_ctx, out_item);
} }
out_item->set_lod(feed_item.lod()); out_item->set_lod(feed_item.lod());
} }
......
...@@ -56,7 +56,7 @@ class FetchOp : public framework::OperatorBase { ...@@ -56,7 +56,7 @@ class FetchOp : public framework::OperatorBase {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(src_item.place()); auto &dev_ctx = *pool.Get(src_item.place());
Copy(src_item, platform::CPUPlace(), dev_ctx, &dst_item); TensorCopy(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
dev_ctx.Wait(); dev_ctx.Wait();
dst_item.set_lod(src_item.lod()); dst_item.set_lod(src_item.lod());
......
...@@ -13,42 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,42 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fill_constant_batch_size_like_op.h" #include "paddle/fluid/operators/fill_constant_batch_size_like_op.h"
#include "paddle/fluid/operators/batch_size_like.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel { class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("Input"),
"Input(Input) of FillConstantBatchSizeLikeOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FillConstantBatchSizeLikeOp should not be null.");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(shape.size(), 0);
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto output_dim = framework::make_ddim(shape_int64);
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
PADDLE_ENFORCE_GE(input_dim_idx, 0);
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
PADDLE_ENFORCE_GE(output_dim_idx, 0);
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
ctx->SetOutputDim("Out", output_dim);
}
protected: protected:
using BatchSizeLikeOp::BatchSizeLikeOp;
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
...@@ -57,28 +29,14 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel { ...@@ -57,28 +29,14 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
} }
}; };
class FillConstantBatchSizeLikeOpMaker class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
: public framework::OpProtoAndCheckerMaker {
public: public:
FillConstantBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker) FillConstantBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : BatchSizeLikeOpMaker(proto, op_checker) {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::DataType::FP32);
AddInput("Input",
"(Tensor) Tensor "
"whose dim_idx th dimension is used to specify the batch_size");
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<int>("input_dim_idx",
"(int, default 0) The index of input's batch size dimension")
.SetDefault(0);
AddAttr<int>("output_dim_idx",
"(int, default 0) The index of output's batch size dimension")
.SetDefault(0);
AddAttr<float>("value", "(float, default 0) The value to be filled") AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f); .SetDefault(0.0f);
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -74,7 +74,7 @@ class FillOp : public framework::OperatorBase { ...@@ -74,7 +74,7 @@ class FillOp : public framework::OperatorBase {
platform::DeviceContextPool &pool = platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
framework::Copy(tensor, place, dev_ctx, &out); framework::TensorCopy(tensor, place, dev_ctx, &out);
} }
} }
}; };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/batch_size_like.h"
namespace paddle {
namespace operators {
class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp {
protected:
using BatchSizeLikeOp::BatchSizeLikeOp;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public:
GaussianRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: BatchSizeLikeOpMaker(proto, op_checker) {
AddAttr<float>("mean",
"(float, default 0.0) "
"mean of random tensor.")
.SetDefault(.0f);
AddAttr<float>("std",
"(float, default 1.0) "
"std of random tensor.")
.SetDefault(1.0f);
AddAttr<int>("seed",
"(int, default 0) "
"Random seed of generator."
"0 means use system wide seed."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"Output data type.")
.SetDefault(framework::proto::DataType::FP32);
AddComment(R"DOC(
GaussianRandom Operator.
Used to initialize tensors with gaussian random generator.
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
gaussian_random_batch_size_like,
paddle::operators::GaussianRandomBatchSizeLikeOp,
paddle::operators::GaussianRandomBatchSizeLikeOpMaker);
// Kernels are registered in gaussian_random_op.cc and gaussian_random_op.cu
...@@ -88,7 +88,9 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -88,7 +88,9 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("seed", AddAttr<int>("seed",
"(int, default 0) " "(int, default 0) "
"Random seed of generator." "Random seed of generator."
"0 means use system wide seed.") "0 means use system wide seed."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5(FP32)) " "(int, default 5(FP32)) "
...@@ -110,4 +112,8 @@ Used to initialize tensors with gaussian random generator. ...@@ -110,4 +112,8 @@ Used to initialize tensors with gaussian random generator.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
ops::GaussianRandomOpMaker); ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>); REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>,
ops::CPUGaussianRandomKernel<double>);
REGISTER_OP_CPU_KERNEL(gaussian_random_batch_size_like,
ops::CPUGaussianRandomKernel<float>,
ops::CPUGaussianRandomKernel<double>);
...@@ -61,4 +61,8 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> { ...@@ -61,4 +61,8 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(gaussian_random, REGISTER_OP_CUDA_KERNEL(gaussian_random,
paddle::operators::GPUGaussianRandomKernel<float>); paddle::operators::GPUGaussianRandomKernel<float>,
paddle::operators::GPUGaussianRandomKernel<double>);
REGISTER_OP_CUDA_KERNEL(gaussian_random_batch_size_like,
paddle::operators::GPUGaussianRandomKernel<float>,
paddle::operators::GPUGaussianRandomKernel<double>);
...@@ -196,7 +196,7 @@ class LayerNormGradKernel : public framework::OpKernel<T> { ...@@ -196,7 +196,7 @@ class LayerNormGradKernel : public framework::OpKernel<T> {
// dy_dx // dy_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>( ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &d_y, scale, /*axis*/ 1, MulFunctor<T>(), &temp); ctx, &d_y, scale, /*axis*/ 1, MulFunctor<T>(), &temp);
framework::Copy(temp, ctx.GetPlace(), ctx.device_context(), d_x); framework::TensorCopy(temp, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx // dy_dmean_dx
row_mean(dev_ctx, temp, &temp_vec); row_mean(dev_ctx, temp, &temp_vec);
...@@ -208,7 +208,7 @@ class LayerNormGradKernel : public framework::OpKernel<T> { ...@@ -208,7 +208,7 @@ class LayerNormGradKernel : public framework::OpKernel<T> {
ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp); ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
} else { } else {
// dy_dx // dy_dx
framework::Copy(d_y, ctx.GetPlace(), ctx.device_context(), d_x); framework::TensorCopy(d_y, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx // dy_dmean_dx
row_mean(dev_ctx, d_y, &temp_vec); row_mean(dev_ctx, d_y, &temp_vec);
......
...@@ -69,7 +69,7 @@ class LoadCombineOp : public framework::OperatorBase { ...@@ -69,7 +69,7 @@ class LoadCombineOp : public framework::OperatorBase {
out_var->Clear(); out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod()); tensor->set_lod(cpu_tensor.lod());
Copy(cpu_tensor, place, dev_ctx, tensor); TensorCopy(cpu_tensor, place, dev_ctx, tensor);
} }
} }
} }
......
...@@ -55,7 +55,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -55,7 +55,7 @@ class LoadOp : public framework::OperatorBase {
out_var->Clear(); out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod()); tensor->set_lod(cpu_tensor.lod());
Copy(cpu_tensor, place, dev_ctx, tensor); TensorCopy(cpu_tensor, place, dev_ctx, tensor);
} }
} }
}; };
......
...@@ -33,8 +33,8 @@ class LoDResetKernel : public framework::OpKernel<T> { ...@@ -33,8 +33,8 @@ class LoDResetKernel : public framework::OpKernel<T> {
auto* lod = lod_t->data<int>(); auto* lod = lod_t->data<int>();
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor lod_cpu; framework::Tensor lod_cpu;
framework::Copy(*lod_t, platform::CPUPlace(), ctx.device_context(), framework::TensorCopy(*lod_t, platform::CPUPlace(),
&lod_cpu); ctx.device_context(), &lod_cpu);
lod = lod_cpu.data<int>(); lod = lod_cpu.data<int>();
} }
level0 = std::vector<int>(lod, lod + lod_t->numel()); level0 = std::vector<int>(lod, lod + lod_t->numel());
......
...@@ -94,7 +94,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase { ...@@ -94,7 +94,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
framework::Copy(x.Slice(static_cast<int>(each_range.begin), framework::TensorCopy(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)), static_cast<int>(each_range.end)),
x.place(), dev_ctx, &slice); x.place(), dev_ctx, &slice);
offset += len; offset += len;
......
...@@ -149,7 +149,8 @@ class ContextProjectFunctor { ...@@ -149,7 +149,8 @@ class ContextProjectFunctor {
Tensor out_t_sub = out_t.Slice(k * context_length, Tensor out_t_sub = out_t.Slice(k * context_length,
k * context_length + padding_size); k * context_length + padding_size);
Tensor w_sub = padding_data.Slice(k, k + padding_size); Tensor w_sub = padding_data.Slice(k, k + padding_size);
framework::Copy(w_sub, context.GetPlace(), context, &out_t_sub); framework::TensorCopy(w_sub, context.GetPlace(), context,
&out_t_sub);
} }
} }
if (down_pad > 0) { // add down pad if (down_pad > 0) { // add down pad
...@@ -179,7 +180,8 @@ class ContextProjectFunctor { ...@@ -179,7 +180,8 @@ class ContextProjectFunctor {
(down_pad_begin_row + t) * context_length); (down_pad_begin_row + t) * context_length);
Tensor w_sub = padding_data.Slice( Tensor w_sub = padding_data.Slice(
up_pad + padding_idx, up_pad + padding_idx + padding_size); up_pad + padding_idx, up_pad + padding_idx + padding_size);
framework::Copy(w_sub, context.GetPlace(), context, &out_t_sub); framework::TensorCopy(w_sub, context.GetPlace(), context,
&out_t_sub);
} }
} }
out_t.Resize({sequence_height, context_length * sequence_width}); out_t.Resize({sequence_height, context_length * sequence_width});
......
...@@ -62,7 +62,7 @@ void testIm2col() { ...@@ -62,7 +62,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp; input = input_tmp;
} else { } else {
Copy(input_tmp, *place, *context, &input); TensorCopy(input_tmp, *place, *context, &input);
} }
output_cfo.mutable_data<float>( output_cfo.mutable_data<float>(
{1, filter_size, filter_size, output_height, output_width}, *place); {1, filter_size, filter_size, output_height, output_width}, *place);
...@@ -87,7 +87,7 @@ void testIm2col() { ...@@ -87,7 +87,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
out_cfo_ptr = output_cfo.data<float>(); out_cfo_ptr = output_cfo.data<float>();
} else { } else {
Copy(output_cfo, paddle::platform::CPUPlace(), *context, &output_tmp); TensorCopy(output_cfo, paddle::platform::CPUPlace(), *context, &output_tmp);
out_cfo_ptr = output_tmp.data<float>(); out_cfo_ptr = output_tmp.data<float>();
} }
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
...@@ -98,7 +98,7 @@ void testIm2col() { ...@@ -98,7 +98,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
out_ocf_ptr = output_ocf.data<float>(); out_ocf_ptr = output_ocf.data<float>();
} else { } else {
Copy(output_ocf, paddle::platform::CPUPlace(), *context, &output_tmp); TensorCopy(output_ocf, paddle::platform::CPUPlace(), *context, &output_tmp);
out_ocf_ptr = output_tmp.data<float>(); out_ocf_ptr = output_tmp.data<float>();
} }
...@@ -119,7 +119,7 @@ void testIm2col() { ...@@ -119,7 +119,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp; input = input_tmp;
} else { } else {
Copy(input_tmp, *place, *context, &input); TensorCopy(input_tmp, *place, *context, &input);
} }
col2im(*context, output_cfo, dilation, stride, padding, &input); col2im(*context, output_cfo, dilation, stride, padding, &input);
...@@ -128,7 +128,7 @@ void testIm2col() { ...@@ -128,7 +128,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>(); in_ptr = input.data<float>();
} else { } else {
Copy(input, paddle::platform::CPUPlace(), *context, &input_tmp); TensorCopy(input, paddle::platform::CPUPlace(), *context, &input_tmp);
in_ptr = input_tmp.data<float>(); in_ptr = input_tmp.data<float>();
} }
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
...@@ -140,7 +140,7 @@ void testIm2col() { ...@@ -140,7 +140,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp; input = input_tmp;
} else { } else {
Copy(input_tmp, *place, *context, &input); TensorCopy(input_tmp, *place, *context, &input);
} }
col2im_ocf(*context, output_ocf, dilation, stride, padding, &input); col2im_ocf(*context, output_ocf, dilation, stride, padding, &input);
...@@ -148,7 +148,7 @@ void testIm2col() { ...@@ -148,7 +148,7 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>(); in_ptr = input.data<float>();
} else { } else {
Copy(input, paddle::platform::CPUPlace(), *context, &input_tmp); TensorCopy(input, paddle::platform::CPUPlace(), *context, &input_tmp);
in_ptr = input_tmp.data<float>(); in_ptr = input_tmp.data<float>();
} }
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
......
...@@ -29,15 +29,15 @@ TEST(math_function, notrans_mul_trans) { ...@@ -29,15 +29,15 @@ TEST(math_function, notrans_mul_trans) {
auto* gpu_place = new paddle::platform::CUDAPlace(0); auto* gpu_place = new paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place); paddle::platform::CUDADeviceContext context(*gpu_place);
paddle::framework::Copy(input1, *gpu_place, context, &input1_gpu); paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu);
paddle::framework::Copy(input1, *gpu_place, context, &input2_gpu); paddle::framework::TensorCopy(input1, *gpu_place, context, &input2_gpu);
out_gpu.mutable_data<float>({2, 2}, *gpu_place); out_gpu.mutable_data<float>({2, 2}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::CUDADeviceContext, float>( paddle::operators::math::matmul<paddle::platform::CUDADeviceContext, float>(
context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0); context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0);
paddle::framework::Copy(out_gpu, *cpu_place, context, &out); paddle::framework::TensorCopy(out_gpu, *cpu_place, context, &out);
float* out_ptr = out.data<float>(); float* out_ptr = out.data<float>();
context.Wait(); context.Wait();
...@@ -63,15 +63,15 @@ TEST(math_function, trans_mul_notrans) { ...@@ -63,15 +63,15 @@ TEST(math_function, trans_mul_notrans) {
auto* gpu_place = new paddle::platform::CUDAPlace(0); auto* gpu_place = new paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place); paddle::platform::CUDADeviceContext context(*gpu_place);
paddle::framework::Copy(input1, *gpu_place, context, &input1_gpu); paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu);
paddle::framework::Copy(input1, *gpu_place, context, &input2_gpu); paddle::framework::TensorCopy(input1, *gpu_place, context, &input2_gpu);
out_gpu.mutable_data<float>({3, 3}, *gpu_place); out_gpu.mutable_data<float>({3, 3}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::CUDADeviceContext, float>( paddle::operators::math::matmul<paddle::platform::CUDADeviceContext, float>(
context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0); context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0);
paddle::framework::Copy(out_gpu, *cpu_place, context, &out); paddle::framework::TensorCopy(out_gpu, *cpu_place, context, &out);
float* out_ptr = out.data<float>(); float* out_ptr = out.data<float>();
context.Wait(); context.Wait();
...@@ -112,9 +112,9 @@ TEST(math_function, gemm_notrans_cublas) { ...@@ -112,9 +112,9 @@ TEST(math_function, gemm_notrans_cublas) {
auto* gpu_place = new paddle::platform::CUDAPlace(0); auto* gpu_place = new paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place); paddle::platform::CUDADeviceContext context(*gpu_place);
paddle::framework::Copy(input1, *gpu_place, context, &input1_gpu); paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu);
paddle::framework::Copy(input2, *gpu_place, context, &input2_gpu); paddle::framework::TensorCopy(input2, *gpu_place, context, &input2_gpu);
paddle::framework::Copy(input3, *gpu_place, context, &input3_gpu); paddle::framework::TensorCopy(input3, *gpu_place, context, &input3_gpu);
float* a = input1_gpu.data<float>(); float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>(); float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(*gpu_place); float* c = input3_gpu.mutable_data<float>(*gpu_place);
...@@ -122,7 +122,7 @@ TEST(math_function, gemm_notrans_cublas) { ...@@ -122,7 +122,7 @@ TEST(math_function, gemm_notrans_cublas) {
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>( paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>(
context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4); context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4);
paddle::framework::Copy(input3_gpu, *cpu_place, context, &input3); paddle::framework::TensorCopy(input3_gpu, *cpu_place, context, &input3);
// numpy code: // numpy code:
// a = np.arange(6).reshape(2, 3) // a = np.arange(6).reshape(2, 3)
...@@ -167,9 +167,9 @@ TEST(math_function, gemm_trans_cublas) { ...@@ -167,9 +167,9 @@ TEST(math_function, gemm_trans_cublas) {
auto* gpu_place = new paddle::platform::CUDAPlace(0); auto* gpu_place = new paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place); paddle::platform::CUDADeviceContext context(*gpu_place);
paddle::framework::Copy(input1, *gpu_place, context, &input1_gpu); paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu);
paddle::framework::Copy(input2, *gpu_place, context, &input2_gpu); paddle::framework::TensorCopy(input2, *gpu_place, context, &input2_gpu);
paddle::framework::Copy(input3, *gpu_place, context, &input3_gpu); paddle::framework::TensorCopy(input3, *gpu_place, context, &input3_gpu);
float* a = input1_gpu.data<float>(); float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>(); float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(*gpu_place); float* c = input3_gpu.mutable_data<float>(*gpu_place);
...@@ -177,7 +177,7 @@ TEST(math_function, gemm_trans_cublas) { ...@@ -177,7 +177,7 @@ TEST(math_function, gemm_trans_cublas) {
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>( paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>(
context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4); context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4);
paddle::framework::Copy(input3_gpu, *cpu_place, context, &input3); paddle::framework::TensorCopy(input3_gpu, *cpu_place, context, &input3);
context.Wait(); context.Wait();
EXPECT_EQ(input3_ptr[0], 0); EXPECT_EQ(input3_ptr[0], 0);
...@@ -218,14 +218,14 @@ void GemvTest(int m, int n, bool trans) { ...@@ -218,14 +218,14 @@ void GemvTest(int m, int n, bool trans) {
} }
paddle::platform::CUDADeviceContext context(*gpu_place); paddle::platform::CUDADeviceContext context(*gpu_place);
paddle::framework::Copy(mat_a, *gpu_place, context, &g_mat_a); paddle::framework::TensorCopy(mat_a, *gpu_place, context, &g_mat_a);
paddle::framework::Copy(vec_b, *gpu_place, context, &g_vec_b); paddle::framework::TensorCopy(vec_b, *gpu_place, context, &g_vec_b);
paddle::operators::math::gemv<paddle::platform::CUDADeviceContext, T>( paddle::operators::math::gemv<paddle::platform::CUDADeviceContext, T>(
context, trans, static_cast<int>(m), static_cast<int>(n), 1., g_data_a, context, trans, static_cast<int>(m), static_cast<int>(n), 1., g_data_a,
g_data_b, 0., g_data_c); g_data_b, 0., g_data_c);
paddle::framework::Copy(g_vec_c, paddle::platform::CPUPlace(), context, paddle::framework::TensorCopy(g_vec_c, paddle::platform::CPUPlace(), context,
&vec_c); &vec_c);
if (!trans) { if (!trans) {
......
...@@ -67,7 +67,7 @@ TEST(selected_rows_functor, gpu_add) { ...@@ -67,7 +67,7 @@ TEST(selected_rows_functor, gpu_add) {
EXPECT_EQ(out_rows[6], 9); EXPECT_EQ(out_rows[6], 9);
Tensor out_cpu; Tensor out_cpu;
Copy(*out_value, cpu_place, ctx, &out_cpu); TensorCopy(*out_value, cpu_place, ctx, &out_cpu);
ctx.Wait(); ctx.Wait();
auto* out_cpu_data = out_cpu.data<float>(); auto* out_cpu_data = out_cpu.data<float>();
...@@ -94,7 +94,7 @@ TEST(selected_rows_functor, gpu_add) { ...@@ -94,7 +94,7 @@ TEST(selected_rows_functor, gpu_add) {
add_tensor_functor(ctx, *output, *tensor1, tensor2.get()); add_tensor_functor(ctx, *output, *tensor1, tensor2.get());
Tensor tensor2_cpu; Tensor tensor2_cpu;
Copy(*tensor2, cpu_place, ctx, &tensor2_cpu); TensorCopy(*tensor2, cpu_place, ctx, &tensor2_cpu);
ctx.Wait(); ctx.Wait();
auto* tensor2_cpu_data = tensor2_cpu.data<float>(); auto* tensor2_cpu_data = tensor2_cpu.data<float>();
...@@ -167,7 +167,7 @@ TEST(selected_rows_functor, gpu_add_to) { ...@@ -167,7 +167,7 @@ TEST(selected_rows_functor, gpu_add_to) {
EXPECT_EQ(out_rows[6], 9); EXPECT_EQ(out_rows[6], 9);
Tensor out_cpu; Tensor out_cpu;
Copy(*out_value, cpu_place, ctx, &out_cpu); TensorCopy(*out_value, cpu_place, ctx, &out_cpu);
ctx.Wait(); ctx.Wait();
auto* out_cpu_data = out_cpu.data<float>(); auto* out_cpu_data = out_cpu.data<float>();
...@@ -191,7 +191,7 @@ TEST(selected_rows_functor, gpu_add_to) { ...@@ -191,7 +191,7 @@ TEST(selected_rows_functor, gpu_add_to) {
add_to_tensor_functor(ctx, *output, tensor1.get()); add_to_tensor_functor(ctx, *output, tensor1.get());
Tensor tensor1_cpu; Tensor tensor1_cpu;
Copy(*tensor1, cpu_place, ctx, &tensor1_cpu); TensorCopy(*tensor1, cpu_place, ctx, &tensor1_cpu);
ctx.Wait(); ctx.Wait();
auto* tensor1_cpu_data = tensor1_cpu.data<float>(); auto* tensor1_cpu_data = tensor1_cpu.data<float>();
......
...@@ -97,7 +97,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -97,7 +97,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"width of sequence in LoDTensor seq."); "width of sequence in LoDTensor seq.");
if (!norm_by_times && num_sequences == 1UL) { if (!norm_by_times && num_sequences == 1UL) {
Copy(seq, context.GetPlace(), context, &padding); TensorCopy(seq, context.GetPlace(), context, &padding);
padding.Resize(padding_dims); padding.Resize(padding_dims);
return; return;
} }
...@@ -172,7 +172,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -172,7 +172,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"width of sequence in LoDTensor seq."); "width of sequence in LoDTensor seq.");
if (!norm_by_times && num_sequences == 1UL) { if (!norm_by_times && num_sequences == 1UL) {
Copy(padding, context.GetPlace(), context, &seq); TensorCopy(padding, context.GetPlace(), context, &seq);
seq.Resize(seq_dims); seq.Resize(seq_dims);
return; return;
} }
......
...@@ -40,7 +40,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod, ...@@ -40,7 +40,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
seq = cpu_seq; seq = cpu_seq;
} else { } else {
Copy(cpu_seq, *place, *context, &seq); TensorCopy(cpu_seq, *place, *context, &seq);
seq.set_lod(lod); seq.set_lod(lod);
} }
...@@ -63,7 +63,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod, ...@@ -63,7 +63,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
cpu_seq_back = seq_back; cpu_seq_back = seq_back;
} else { } else {
Copy(seq_back, paddle::platform::CPUPlace(), *context, &cpu_seq_back); TensorCopy(seq_back, paddle::platform::CPUPlace(), *context, &cpu_seq_back);
cpu_seq_back.set_lod(lod); cpu_seq_back.set_lod(lod);
} }
......
...@@ -71,7 +71,7 @@ void testVol2col() { ...@@ -71,7 +71,7 @@ void testVol2col() {
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp; input = input_tmp;
} else { } else {
Copy(input_tmp, *place, *context, &input); paddle::framework::TensorCopy(input_tmp, *place, *context, &input);
} }
output.mutable_data<float>({1, filter_size, filter_size, filter_size, output.mutable_data<float>({1, filter_size, filter_size, filter_size,
output_depth, output_height, output_width}, output_depth, output_height, output_width},
...@@ -85,7 +85,7 @@ void testVol2col() { ...@@ -85,7 +85,7 @@ void testVol2col() {
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
out_cfo_ptr = output.data<float>(); out_cfo_ptr = output.data<float>();
} else { } else {
Copy(output, paddle::platform::CPUPlace(), *context, &output_tmp); TensorCopy(output, paddle::platform::CPUPlace(), *context, &output_tmp);
out_cfo_ptr = output_tmp.data<float>(); out_cfo_ptr = output_tmp.data<float>();
} }
...@@ -99,7 +99,7 @@ void testVol2col() { ...@@ -99,7 +99,7 @@ void testVol2col() {
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp; input = input_tmp;
} else { } else {
Copy(input_tmp, *place, *context, &input); TensorCopy(input_tmp, *place, *context, &input);
} }
paddle::operators::math::Col2VolFunctor<DeviceContext, float> col2vol; paddle::operators::math::Col2VolFunctor<DeviceContext, float> col2vol;
...@@ -109,7 +109,7 @@ void testVol2col() { ...@@ -109,7 +109,7 @@ void testVol2col() {
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>(); in_ptr = input.data<float>();
} else { } else {
Copy(input, paddle::platform::CPUPlace(), *context, &input_tmp); TensorCopy(input, paddle::platform::CPUPlace(), *context, &input_tmp);
in_ptr = input_tmp.data<float>(); in_ptr = input_tmp.data<float>();
} }
......
...@@ -51,7 +51,8 @@ class MergeLoDTensorOp : public framework::OperatorBase { ...@@ -51,7 +51,8 @@ class MergeLoDTensorOp : public framework::OperatorBase {
cpu_mask->ShareDataWith(mask); cpu_mask->ShareDataWith(mask);
} else if (platform::is_gpu_place(mask.place())) { } else if (platform::is_gpu_place(mask.place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
framework::Copy(mask, platform::CPUPlace(), dev_ctx, cpu_mask.get()); framework::TensorCopy(mask, platform::CPUPlace(), dev_ctx,
cpu_mask.get());
#else #else
PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option");
#endif #endif
...@@ -106,8 +107,8 @@ class MergeLoDTensorOp : public framework::OperatorBase { ...@@ -106,8 +107,8 @@ class MergeLoDTensorOp : public framework::OperatorBase {
continue; continue;
} }
auto slice = out->Slice(out_offset, out_offset + len); auto slice = out->Slice(out_offset, out_offset + len);
framework::Copy(input->Slice(start_offset, end_offset), place, dev_ctx, framework::TensorCopy(input->Slice(start_offset, end_offset), place,
&slice); dev_ctx, &slice);
out_offset += len; out_offset += len;
(*in_idx) += 1; (*in_idx) += 1;
} }
......
...@@ -67,7 +67,8 @@ class MineHardExamplesKernel : public framework::OpKernel<T> { ...@@ -67,7 +67,8 @@ class MineHardExamplesKernel : public framework::OpKernel<T> {
auto out_match_indices = auto out_match_indices =
ctx.Output<framework::Tensor>("UpdatedMatchIndices"); ctx.Output<framework::Tensor>("UpdatedMatchIndices");
framework::Copy(*in_matched_indices, ctx.GetPlace(), out_match_indices); framework::TensorCopy(*in_matched_indices, ctx.GetPlace(),
out_match_indices);
int batch_size = in_matched_indices->dims()[0]; int batch_size = in_matched_indices->dims()[0];
int prior_num = in_matched_indices->dims()[1]; int prior_num = in_matched_indices->dims()[1];
......
...@@ -33,7 +33,7 @@ class MultiplexGPUKernel : public framework::OpKernel<T> { ...@@ -33,7 +33,7 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
auto cols = ins[0]->numel() / rows; auto cols = ins[0]->numel() / rows;
// copy index to cpu // copy index to cpu
Tensor index_t_cpu; Tensor index_t_cpu;
Copy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu); TensorCopy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>(); auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
platform::CUDAPlace place = boost::get<platform::CUDAPlace>(ctx.GetPlace()); platform::CUDAPlace place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
...@@ -69,7 +69,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> { ...@@ -69,7 +69,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
auto cols = ins[0]->numel() / rows; auto cols = ins[0]->numel() / rows;
// copy index to cpu // copy index to cpu
Tensor index_t_cpu; Tensor index_t_cpu;
Copy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu); TensorCopy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>(); auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
......
...@@ -98,7 +98,7 @@ class NCCLTester : public ::testing::Test { ...@@ -98,7 +98,7 @@ class NCCLTester : public ::testing::Test {
send_tensor->mutable_data<T>(kDims, place); send_tensor->mutable_data<T>(kDims, place);
std::vector<T> send_vector(f::product(kDims), gpu_id); std::vector<T> send_vector(f::product(kDims), gpu_id);
paddle::framework::CopyFromVector<T>(send_vector, *ctx, send_tensor); paddle::framework::TensorFromVector<T>(send_vector, *ctx, send_tensor);
ctx->Wait(); ctx->Wait();
VLOG(1) << "Send Tensor filled with elements " << send_tensor->numel(); VLOG(1) << "Send Tensor filled with elements " << send_tensor->numel();
} }
......
...@@ -79,7 +79,7 @@ inline void CopyOrShare(const framework::Variable &src, ...@@ -79,7 +79,7 @@ inline void CopyOrShare(const framework::Variable &src,
dst->GetMutable<LoDTensor>()->ShareDataWith(src.Get<LoDTensor>()); dst->GetMutable<LoDTensor>()->ShareDataWith(src.Get<LoDTensor>());
dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod()); dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
} else { } else {
Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>()); TensorCopy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
} }
} else if (src.IsType<SelectedRows>()) { } else if (src.IsType<SelectedRows>()) {
auto &src_sr = src.Get<SelectedRows>(); auto &src_sr = src.Get<SelectedRows>();
...@@ -89,7 +89,7 @@ inline void CopyOrShare(const framework::Variable &src, ...@@ -89,7 +89,7 @@ inline void CopyOrShare(const framework::Variable &src,
dst_sr->mutable_value()->ShareDataWith(src_sr.value()); dst_sr->mutable_value()->ShareDataWith(src_sr.value());
dst_sr->set_rows(src_sr.rows()); dst_sr->set_rows(src_sr.rows());
} else { } else {
Copy(src_sr.value(), dst_place, dst_sr->mutable_value()); TensorCopy(src_sr.value(), dst_place, dst_sr->mutable_value());
} }
} else { } else {
PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name()); PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name());
...@@ -147,7 +147,7 @@ class ParallelDoOp : public framework::OperatorBase { ...@@ -147,7 +147,7 @@ class ParallelDoOp : public framework::OperatorBase {
auto &place = places[i]; auto &place = places[i];
auto *sub_scope = sub_scopes[i]; auto *sub_scope = sub_scopes[i];
auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>(); auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>();
framework::Copy(src, place, dst); framework::TensorCopy(src, place, dst);
} }
} }
WaitOnPlaces(places); WaitOnPlaces(places);
......
...@@ -179,7 +179,7 @@ class TensorPrintOp : public framework::OperatorBase { ...@@ -179,7 +179,7 @@ class TensorPrintOp : public framework::OperatorBase {
} else { } else {
// copy data to cpu to print // copy data to cpu to print
platform::CPUPlace place; platform::CPUPlace place;
framework::Copy(in_tensor, place, &printed_tensor); framework::TensorCopy(in_tensor, place, &printed_tensor);
} }
Formater formater; Formater formater;
......
...@@ -291,7 +291,7 @@ class RecurrentOp : public RecurrentBase { ...@@ -291,7 +291,7 @@ class RecurrentOp : public RecurrentBase {
auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1); auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
// Explicit copy output since the local RNN scope can be destroyed // Explicit copy output since the local RNN scope can be destroyed
// early. // early.
framework::Copy(src_tensor, place, dev_ctx, &dst_out); framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
}); });
scopes.Next(); scopes.Next();
...@@ -378,7 +378,7 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -378,7 +378,7 @@ class RecurrentGradOp : public RecurrentBase {
auto *cur_grad_var = cur_scope.Var(cur_grad); auto *cur_grad_var = cur_scope.Var(cur_grad);
auto cur_grad_tensor = auto cur_grad_tensor =
cur_grad_var->GetMutable<framework::LoDTensor>(); cur_grad_var->GetMutable<framework::LoDTensor>();
framework::Copy(ex_tensor, place, dev_ctx, cur_grad_tensor); framework::TensorCopy(ex_tensor, place, dev_ctx, cur_grad_tensor);
} }
} }
...@@ -452,7 +452,7 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -452,7 +452,7 @@ class RecurrentGradOp : public RecurrentBase {
} }
auto dst = outside->Slice(seq_offset, seq_offset + 1); auto dst = outside->Slice(seq_offset, seq_offset + 1);
framework::Copy(inside, place, dev_ctx, &dst); framework::TensorCopy(inside, place, dev_ctx, &dst);
}); });
VLOG(5) << "Link outside gradient finished "; VLOG(5) << "Link outside gradient finished ";
...@@ -465,7 +465,7 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -465,7 +465,7 @@ class RecurrentGradOp : public RecurrentBase {
framework::LoDTensor *outside) { framework::LoDTensor *outside) {
outside->Resize(inside.dims()); outside->Resize(inside.dims());
outside->mutable_data(place, inside.type()); outside->mutable_data(place, inside.type());
framework::Copy(inside, place, dev_ctx, outside); framework::TensorCopy(inside, place, dev_ctx, outside);
}); });
VLOG(5) << "Link initialize state gradient finished "; VLOG(5) << "Link initialize state gradient finished ";
} }
......
...@@ -170,7 +170,7 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { ...@@ -170,7 +170,7 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
framework::Copy(x_sliced, out_sliced.place(), dev_ctx, &out_sliced); framework::TensorCopy(x_sliced, out_sliced.place(), dev_ctx, &out_sliced);
out_offset += len; out_offset += len;
return out_offset; return out_offset;
} }
......
...@@ -28,7 +28,7 @@ class ReshapeKernel : public framework::OpKernel<T> { ...@@ -28,7 +28,7 @@ class ReshapeKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
auto out_dims = out->dims(); auto out_dims = out->dims();
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
framework::Copy(*in, ctx.GetPlace(), ctx.device_context(), out); framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
out->Resize(out_dims); out->Resize(out_dims);
} }
}; };
...@@ -42,7 +42,7 @@ class ReshapeGradKernel : public framework::OpKernel<T> { ...@@ -42,7 +42,7 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
auto in_dims = d_x->dims(); auto in_dims = d_x->dims();
framework::Copy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
d_x->Resize(in_dims); d_x->Resize(in_dims);
} }
}; };
......
...@@ -61,7 +61,7 @@ class SequenceReshapeKernel : public framework::OpKernel<T> { ...@@ -61,7 +61,7 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
} }
} }
framework::Copy(*in, context.GetPlace(), out); framework::TensorCopy(*in, context.GetPlace(), out);
out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width}); out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
} }
}; };
...@@ -77,7 +77,7 @@ class SequenceReshapeGradKernel : public framework::OpKernel<T> { ...@@ -77,7 +77,7 @@ class SequenceReshapeGradKernel : public framework::OpKernel<T> {
context.Output<LoDTensor>(framework::GradVarName("X")); context.Output<LoDTensor>(framework::GradVarName("X"));
xg_tensor_ptr->mutable_data<T>(context.GetPlace()); xg_tensor_ptr->mutable_data<T>(context.GetPlace());
framework::Copy(*outg_tensor_ptr, context.GetPlace(), xg_tensor_ptr); framework::TensorCopy(*outg_tensor_ptr, context.GetPlace(), xg_tensor_ptr);
xg_tensor_ptr->Resize(x_tensor_ptr->dims()); xg_tensor_ptr->Resize(x_tensor_ptr->dims());
} }
}; };
......
...@@ -66,12 +66,12 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> { ...@@ -66,12 +66,12 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace()); offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
framework::Copy(*offset, platform::CPUPlace(), ctx.device_context(), framework::TensorCopy(*offset, platform::CPUPlace(), ctx.device_context(),
&offset_cpu); &offset_cpu);
offset_data = offset_cpu.data<int64_t>(); offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace()); length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
framework::Copy(*length, platform::CPUPlace(), ctx.device_context(), framework::TensorCopy(*length, platform::CPUPlace(), ctx.device_context(),
&length_cpu); &length_cpu);
length_data = length_cpu.data<int64_t>(); length_data = length_cpu.data<int64_t>();
} }
...@@ -127,12 +127,12 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> { ...@@ -127,12 +127,12 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace()); offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
framework::Copy(*offset, platform::CPUPlace(), ctx.device_context(), framework::TensorCopy(*offset, platform::CPUPlace(), ctx.device_context(),
&offset_cpu); &offset_cpu);
offset_data = offset_cpu.data<int64_t>(); offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace()); length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
framework::Copy(*length, platform::CPUPlace(), ctx.device_context(), framework::TensorCopy(*length, platform::CPUPlace(), ctx.device_context(),
&length_cpu); &length_cpu);
length_data = length_cpu.data<int64_t>(); length_data = length_cpu.data<int64_t>();
} }
......
...@@ -133,7 +133,7 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { ...@@ -133,7 +133,7 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
auto &dout_tensor = dout_var->Get<framework::LoDTensor>(); auto &dout_tensor = dout_var->Get<framework::LoDTensor>();
auto height = dout_tensor.dims()[0]; auto height = dout_tensor.dims()[0];
auto slice = dx_tensor.Slice(0, static_cast<int>(height)); auto slice = dx_tensor.Slice(0, static_cast<int>(height));
framework::Copy(dout_tensor, dout_tensor.place(), dev_ctx, &slice); framework::TensorCopy(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
if (dx_tensor.dims()[0] > height) { if (dx_tensor.dims()[0] > height) {
auto rest_tensor = dx_tensor.Slice( auto rest_tensor = dx_tensor.Slice(
static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0])); static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0]));
......
...@@ -55,7 +55,8 @@ class SplitLoDTensorOp : public framework::OperatorBase { ...@@ -55,7 +55,8 @@ class SplitLoDTensorOp : public framework::OperatorBase {
cpu_mask->ShareDataWith(mask); cpu_mask->ShareDataWith(mask);
} else if (platform::is_gpu_place(mask.place())) { } else if (platform::is_gpu_place(mask.place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
framework::Copy(mask, platform::CPUPlace(), dev_ctx, cpu_mask.get()); framework::TensorCopy(mask, platform::CPUPlace(), dev_ctx,
cpu_mask.get());
#else #else
PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option");
#endif #endif
...@@ -113,7 +114,7 @@ class SplitLoDTensorOp : public framework::OperatorBase { ...@@ -113,7 +114,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
// out[offset: offset+len] = x[each_range.begin: each_range.end] // out[offset: offset+len] = x[each_range.begin: each_range.end]
auto slice = out->Slice(static_cast<int>(offset), auto slice = out->Slice(static_cast<int>(offset),
static_cast<int>(offset + len)); static_cast<int>(offset + len));
framework::Copy(x.Slice(static_cast<int>(each_range.begin), framework::TensorCopy(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)), static_cast<int>(each_range.end)),
x.place(), dev_ctx, &slice); x.place(), dev_ctx, &slice);
offset += len; offset += len;
......
...@@ -137,7 +137,7 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -137,7 +137,7 @@ class SumKernel : public framework::OpKernel<T> {
out_array.resize(i + 1); out_array.resize(i + 1);
} }
if (out_array[i].numel() == 0) { if (out_array[i].numel() == 0) {
framework::Copy(in_array[i], in_array[i].place(), framework::TensorCopy(in_array[i], in_array[i].place(),
context.device_context(), &out_array[i]); context.device_context(), &out_array[i]);
out_array[i].set_lod(in_array[i].lod()); out_array[i].set_lod(in_array[i].lod());
} else { } else {
......
...@@ -45,7 +45,7 @@ class WriteToArrayOp : public ArrayOp { ...@@ -45,7 +45,7 @@ class WriteToArrayOp : public ArrayOp {
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
Copy(x_tensor, place, dev_ctx, out_tensor); TensorCopy(x_tensor, place, dev_ctx, out_tensor);
out_tensor->set_lod(x_tensor.lod()); out_tensor->set_lod(x_tensor.lod());
} else { } else {
VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so " VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so "
...@@ -138,7 +138,7 @@ class ReadFromArrayOp : public ArrayOp { ...@@ -138,7 +138,7 @@ class ReadFromArrayOp : public ArrayOp {
platform::DeviceContextPool &pool = platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
framework::Copy(x_array[offset], place, dev_ctx, out_tensor); framework::TensorCopy(x_array[offset], place, dev_ctx, out_tensor);
out_tensor->set_lod(x_array[offset].lod()); out_tensor->set_lod(x_array[offset].lod());
} else { } else {
VLOG(10) << "offset " << offset << " >= " << x_array.size(); VLOG(10) << "offset " << offset << " >= " << x_array.size();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/batch_size_like.h"
namespace paddle {
namespace operators {
class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp {
protected:
using BatchSizeLikeOp::BatchSizeLikeOp;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class UniformRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public:
UniformRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: BatchSizeLikeOpMaker(proto, op_checker) {
AddComment(R"DOC(
Uniform random operator
This operator initializes a tensor with the same batch_size as the Input tensor
with random values sampled from a uniform distribution.
)DOC");
AddAttr<float>("min",
"(float, default -1.0) "
"Minimum value of uniform random")
.SetDefault(-1.0f);
AddAttr<float>("max",
"(float, default 1.0) "
"Maximun value of uniform random")
.SetDefault(1.0f);
AddAttr<int>("seed",
"(int, default 0) "
"Random seed used for generating samples. "
"0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::proto::DataType::FP32);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
uniform_random_batch_size_like,
paddle::operators::UniformRandomBatchSizeLikeOp,
paddle::operators::UniformRandomBatchSizeLikeOpMaker);
// Kernels are registered in uniform_random_op.cc and uniform_random_op.cu
...@@ -96,7 +96,9 @@ uniform distribution. ...@@ -96,7 +96,9 @@ uniform distribution.
AddAttr<int>("seed", AddAttr<int>("seed",
"(int, default 0) " "(int, default 0) "
"Random seed used for generating samples. " "Random seed used for generating samples. "
"0 means use a seed generated by the system.") "0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type") AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::DataType::FP32);
...@@ -110,3 +112,6 @@ REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, ...@@ -110,3 +112,6 @@ REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
REGISTER_OP_CPU_KERNEL(uniform_random, REGISTER_OP_CPU_KERNEL(uniform_random,
paddle::operators::CPUUniformRandomKernel<float>, paddle::operators::CPUUniformRandomKernel<float>,
paddle::operators::CPUUniformRandomKernel<double>); paddle::operators::CPUUniformRandomKernel<double>);
REGISTER_OP_CPU_KERNEL(uniform_random_batch_size_like,
paddle::operators::CPUUniformRandomKernel<float>,
paddle::operators::CPUUniformRandomKernel<double>);
...@@ -66,3 +66,6 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -66,3 +66,6 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
REGISTER_OP_CUDA_KERNEL(uniform_random, REGISTER_OP_CUDA_KERNEL(uniform_random,
paddle::operators::GPUUniformRandomKernel<float>, paddle::operators::GPUUniformRandomKernel<float>,
paddle::operators::GPUUniformRandomKernel<double>); paddle::operators::GPUUniformRandomKernel<double>);
REGISTER_OP_CUDA_KERNEL(uniform_random_batch_size_like,
paddle::operators::GPUUniformRandomKernel<float>,
paddle::operators::GPUUniformRandomKernel<double>);
...@@ -185,7 +185,8 @@ class WarpCTCKernel : public framework::OpKernel<T> { ...@@ -185,7 +185,8 @@ class WarpCTCKernel : public framework::OpKernel<T> {
// warpctc accesses labels in CPU memory // warpctc accesses labels in CPU memory
Tensor warpctc_label; Tensor warpctc_label;
Copy(*label, platform::CPUPlace(), ctx.device_context(), &warpctc_label); TensorCopy(*label, platform::CPUPlace(), ctx.device_context(),
&warpctc_label);
const int* warpctc_label_data = warpctc_label.data<int>(); const int* warpctc_label_data = warpctc_label.data<int>();
// warpctc stores loss in CPU memory // warpctc stores loss in CPU memory
Tensor warpctc_loss; Tensor warpctc_loss;
...@@ -200,7 +201,7 @@ class WarpCTCKernel : public framework::OpKernel<T> { ...@@ -200,7 +201,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
sequence_width, num_sequences, blank, warpctc_loss_data); sequence_width, num_sequences, blank, warpctc_loss_data);
// Copy the loss back // Copy the loss back
Copy(warpctc_loss, ctx.GetPlace(), ctx.device_context(), loss); TensorCopy(warpctc_loss, ctx.GetPlace(), ctx.device_context(), loss);
} }
}; };
......
...@@ -101,7 +101,7 @@ T TensorGetElement(framework::Tensor &self, size_t offset) { ...@@ -101,7 +101,7 @@ T TensorGetElement(framework::Tensor &self, size_t offset) {
return self.data<T>()[offset]; return self.data<T>()[offset];
} else { } else {
std::shared_ptr<framework::Tensor> dst(new framework::Tensor); std::shared_ptr<framework::Tensor> dst(new framework::Tensor);
framework::Copy(self, platform::CPUPlace(), dst.get()); framework::TensorCopy(self, platform::CPUPlace(), dst.get());
return dst->data<T>()[offset]; return dst->data<T>()[offset];
} }
} }
...@@ -111,9 +111,9 @@ template <typename T> ...@@ -111,9 +111,9 @@ template <typename T>
void TensorSetElement(framework::Tensor &self, size_t offset, T elem) { void TensorSetElement(framework::Tensor &self, size_t offset, T elem) {
if (platform::is_gpu_place(self.place())) { if (platform::is_gpu_place(self.place())) {
std::shared_ptr<framework::Tensor> dst(new framework::Tensor); std::shared_ptr<framework::Tensor> dst(new framework::Tensor);
framework::Copy(self, platform::CPUPlace(), dst.get()); framework::TensorCopy(self, platform::CPUPlace(), dst.get());
dst->data<T>()[offset] = elem; dst->data<T>()[offset] = elem;
framework::Copy(*dst.get(), self.place(), &self); framework::TensorCopy(*dst.get(), self.place(), &self);
} else if (platform::is_cpu_place(self.place())) { } else if (platform::is_cpu_place(self.place())) {
self.data<T>()[offset] = elem; self.data<T>()[offset] = elem;
......
...@@ -68,6 +68,7 @@ __all__ = [ ...@@ -68,6 +68,7 @@ __all__ = [
'layer_norm', 'layer_norm',
'softmax_with_cross_entropy', 'softmax_with_cross_entropy',
'smooth_l1', 'smooth_l1',
'one_hot',
] ]
...@@ -3212,3 +3213,40 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): ...@@ -3212,3 +3213,40 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
'Out': loss}, 'Out': loss},
attrs={'sigma': sigma}) attrs={'sigma': sigma})
return loss return loss
def one_hot(input, depth):
"""
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
operator.
Args:
input(Tensor/LodTensor): A Tensor/LodTensor of indices, last dimension must be 1.
depth(scalar): an interger defining the depth of the one hot dimension.
Returns:
The one-hot tensor or LodTensor, same as input.
Examples:
X is a LoDTensor:
X.lod = [[0, 1, 4]]
X.shape = [4, 1]
X.data = [[1], [1], [3], [0]]
set depth = 4
Out is a LoDTensor:
Out.lod = [[0, 1, 4]]
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]]
"""
helper = LayerHelper("one_hot", **locals())
one_hot_out = helper.create_tmp_variable(dtype='float32')
helper.append_op(
type="one_hot",
inputs={'X': input},
attrs={'depth': depth},
outputs={'Out': one_hot_out})
return one_hot_out
...@@ -66,6 +66,9 @@ __all__ = [ ...@@ -66,6 +66,9 @@ __all__ = [
'logical_xor', 'logical_xor',
'logical_not', 'logical_not',
'uniform_random', 'uniform_random',
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
'cumsum', 'cumsum',
] + __activations__ ] + __activations__
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function
import unittest import unittest
import paddle.v2.fluid as fluid import paddle.v2.fluid as fluid
...@@ -23,7 +24,8 @@ import sys ...@@ -23,7 +24,8 @@ import sys
def convolution_net(data, label, input_dim, class_dim=2, emb_dim=32, def convolution_net(data, label, input_dim, class_dim=2, emb_dim=32,
hid_dim=32): hid_dim=32):
emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim]) emb = fluid.layers.embedding(
input=data, size=[input_dim, emb_dim], is_sparse=True)
conv_3 = fluid.nets.sequence_conv_pool( conv_3 = fluid.nets.sequence_conv_pool(
input=emb, input=emb,
num_filters=hid_dim, num_filters=hid_dim,
...@@ -41,8 +43,6 @@ def convolution_net(data, label, input_dim, class_dim=2, emb_dim=32, ...@@ -41,8 +43,6 @@ def convolution_net(data, label, input_dim, class_dim=2, emb_dim=32,
act="softmax") act="softmax")
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
adam_optimizer = fluid.optimizer.Adam(learning_rate=0.002)
adam_optimizer.minimize(avg_cost)
accuracy = fluid.layers.accuracy(input=prediction, label=label) accuracy = fluid.layers.accuracy(input=prediction, label=label)
return avg_cost, accuracy, prediction return avg_cost, accuracy, prediction
...@@ -56,7 +56,8 @@ def stacked_lstm_net(data, ...@@ -56,7 +56,8 @@ def stacked_lstm_net(data,
stacked_num=3): stacked_num=3):
assert stacked_num % 2 == 1 assert stacked_num % 2 == 1
emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim]) emb = fluid.layers.embedding(
input=data, size=[input_dim, emb_dim], is_sparse=True)
# add bias attr # add bias attr
# TODO(qijun) linear act # TODO(qijun) linear act
...@@ -79,8 +80,6 @@ def stacked_lstm_net(data, ...@@ -79,8 +80,6 @@ def stacked_lstm_net(data,
act='softmax') act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
adam_optimizer = fluid.optimizer.Adam(learning_rate=0.002)
adam_optimizer.minimize(avg_cost)
accuracy = fluid.layers.accuracy(input=prediction, label=label) accuracy = fluid.layers.accuracy(input=prediction, label=label)
return avg_cost, accuracy, prediction return avg_cost, accuracy, prediction
...@@ -93,7 +92,7 @@ def create_random_lodtensor(lod, place, low, high): ...@@ -93,7 +92,7 @@ def create_random_lodtensor(lod, place, low, high):
return res return res
def train(word_dict, net_method, use_cuda, save_dirname=None): def train(word_dict, net_method, use_cuda, parallel=False, save_dirname=None):
BATCH_SIZE = 128 BATCH_SIZE = 128
PASS_NUM = 5 PASS_NUM = 5
dict_dim = len(word_dict) dict_dim = len(word_dict)
...@@ -102,8 +101,30 @@ def train(word_dict, net_method, use_cuda, save_dirname=None): ...@@ -102,8 +101,30 @@ def train(word_dict, net_method, use_cuda, save_dirname=None):
data = fluid.layers.data( data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1) name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64") label = fluid.layers.data(name="label", shape=[1], dtype="int64")
if not parallel:
cost, acc_out, prediction = net_method( cost, acc_out, prediction = net_method(
data, label, input_dim=dict_dim, class_dim=class_dim) data, label, input_dim=dict_dim, class_dim=class_dim)
else:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
cost, acc, _ = net_method(
pd.read_input(data),
pd.read_input(label),
input_dim=dict_dim,
class_dim=class_dim)
pd.write_output(cost)
pd.write_output(acc)
cost, acc = pd()
cost = fluid.layers.mean(x=cost)
acc_out = fluid.layers.mean(x=acc)
prediction = None
assert save_dirname is None
adagrad = fluid.optimizer.Adagrad(learning_rate=0.002)
adagrad.minimize(cost)
train_data = paddle.batch( train_data = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
...@@ -164,14 +185,16 @@ def infer(use_cuda, save_dirname=None): ...@@ -164,14 +185,16 @@ def infer(use_cuda, save_dirname=None):
print("Inference results: ", np_data) print("Inference results: ", np_data)
def main(word_dict, net_method, use_cuda): def main(word_dict, net_method, use_cuda, parallel=False, save_dirname=None):
if use_cuda and not fluid.core.is_compiled_with_cuda(): if use_cuda and not fluid.core.is_compiled_with_cuda():
return return
# Directory for saving the trained model train(
save_dirname = "understand_sentiment.inference.model" word_dict,
net_method,
train(word_dict, net_method, use_cuda, save_dirname) use_cuda,
parallel=parallel,
save_dirname=save_dirname)
infer(use_cuda, save_dirname) infer(use_cuda, save_dirname)
...@@ -191,20 +214,62 @@ class TestUnderstandSentiment(unittest.TestCase): ...@@ -191,20 +214,62 @@ class TestUnderstandSentiment(unittest.TestCase):
def test_conv_cpu(self): def test_conv_cpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=convolution_net, use_cuda=False) main(
self.word_dict,
net_method=convolution_net,
use_cuda=False,
save_dirname="understand_sentiment.inference.model")
def test_conv_cpu_parallel(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=convolution_net,
use_cuda=False,
parallel=True)
@unittest.skip(reason="make CI faster")
def test_stacked_lstm_cpu(self): def test_stacked_lstm_cpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=stacked_lstm_net, use_cuda=False) main(self.word_dict, net_method=stacked_lstm_net, use_cuda=False)
def test_stacked_lstm_cpu_parallel(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=stacked_lstm_net,
use_cuda=False,
parallel=True)
def test_conv_gpu(self): def test_conv_gpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=convolution_net, use_cuda=True) main(
self.word_dict,
net_method=convolution_net,
use_cuda=True,
save_dirname="understand_sentiment.inference.model")
def test_conv_gpu_parallel(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=convolution_net,
use_cuda=True,
parallel=True)
@unittest.skip(reason="make CI faster")
def test_stacked_lstm_gpu(self): def test_stacked_lstm_gpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=stacked_lstm_net, use_cuda=True) main(self.word_dict, net_method=stacked_lstm_net, use_cuda=True)
def test_stacked_lstm_gpu_parallel(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=stacked_lstm_net,
use_cuda=True,
parallel=True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -248,7 +248,11 @@ class OpTest(unittest.TestCase): ...@@ -248,7 +248,11 @@ class OpTest(unittest.TestCase):
return feed_map return feed_map
def check_output_with_place(self, place, atol): def calc_output(self, place):
outs, _ = self._calc_output(place)
return outs
def _calc_output(self, place):
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type) op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
program = Program() program = Program()
...@@ -281,7 +285,10 @@ class OpTest(unittest.TestCase): ...@@ -281,7 +285,10 @@ class OpTest(unittest.TestCase):
feed=feed_map, feed=feed_map,
fetch_list=fetch_list, fetch_list=fetch_list,
return_numpy=False) return_numpy=False)
return outs, fetch_list
def check_output_with_place(self, place, atol):
outs, fetch_list = self._calc_output(place)
for out_name, out_dup in Operator.get_op_outputs(self.op_type): for out_name, out_dup in Operator.get_op_outputs(self.op_type):
if out_name not in self.outputs: if out_name not in self.outputs:
continue continue
...@@ -340,6 +347,15 @@ class OpTest(unittest.TestCase): ...@@ -340,6 +347,15 @@ class OpTest(unittest.TestCase):
for place in places: for place in places:
self.check_output_with_place(place, atol) self.check_output_with_place(place, atol)
def check_output_customized(self, checker):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0))
for place in places:
outs = self.calc_output(place)
outs = [np.array(out) for out in outs]
checker(outs)
def __assert_is_close(self, numeric_grads, analytic_grads, names, def __assert_is_close(self, numeric_grads, analytic_grads, names,
max_relative_error, msg_prefix): max_relative_error, msg_prefix):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestGaussianRandomBatchSizeLike(OpTest):
def setUp(self):
self.op_type = "gaussian_random_batch_size_like"
self.inputs = {'Input': np.zeros((500, 2000), dtype="float32")}
self.attrs = {'mean': 1., 'std': 2., 'shape': [-1, 2000]}
self.outputs = {'Out': np.zeros((500, 2000), dtype='float32')}
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
self.assertEqual(outs[0].shape, (500, 2000))
hist, _ = np.histogram(outs[0], range=(-3, 5))
hist = hist.astype("float32")
hist /= float(outs[0].size)
data = np.random.normal(size=(500, 2000), loc=1, scale=2)
hist2, _ = np.histogram(data, range=(-3, 5))
hist2 = hist2.astype("float32")
hist2 /= float(outs[0].size)
self.assertTrue(
np.allclose(
hist, hist2, rtol=0, atol=0.01),
"hist: " + str(hist) + " hist2: " + str(hist2))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestUniformRandomBatchSizeLike(OpTest):
def setUp(self):
self.op_type = "uniform_random_batch_size_like"
self.inputs = {'Input': np.zeros((500, 2000), dtype="float32")}
self.attrs = {'min': 1., 'max': 2., 'shape': [-1, 2000]}
self.outputs = {'Out': np.zeros((500, 2000), dtype='float32')}
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
self.assertEqual(outs[0].shape, (500, 2000))
hist, _ = np.histogram(outs[0], range=(1, 2))
hist = hist.astype("float32")
hist /= float(outs[0].size)
prob = 0.1 * np.ones((10))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
if __name__ == "__main__":
unittest.main()
...@@ -13,14 +13,11 @@ ...@@ -13,14 +13,11 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
import numpy import numpy as np
from op_test import OpTest
from paddle.v2.fluid.op import Operator
import paddle.v2.fluid.core as core
import paddle.v2.fluid as fluid
class TestUniformRandomOp(OpTest):
class TestUniformRandomOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.op_type = "uniform_random" self.op_type = "uniform_random"
self.inputs = {} self.inputs = {}
...@@ -30,35 +27,20 @@ class TestUniformRandomOp(unittest.TestCase): ...@@ -30,35 +27,20 @@ class TestUniformRandomOp(unittest.TestCase):
"max": 10.0, "max": 10.0,
"seed": 10 "seed": 10
} }
self.outputs = ["Out"] self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
def test_cpu(self):
self.uniform_random_test(place=core.CPUPlace())
def test_gpu(self):
if core.is_compiled_with_cuda():
self.uniform_random_test(place=core.CUDAPlace(0))
def uniform_random_test(self, place):
program = fluid.Program()
block = program.global_block()
vout = block.create_var(name="Out")
op = block.append_op(
type=self.op_type, outputs={"Out": vout}, attrs=self.attrs)
op.desc.infer_var_type(block.desc) def test_check_output(self):
op.desc.infer_shape(block.desc) self.check_output_customized(self.verify_output)
fetch_list = []
for var_name in self.outputs:
fetch_list.append(block.var(var_name))
exe = fluid.Executor(place)
outs = exe.run(program, fetch_list=fetch_list)
def verify_output(self, outs):
tensor = outs[0] tensor = outs[0]
hist, _ = np.histogram(outs[0], range=(-5, 10))
self.assertAlmostEqual(tensor.mean(), 2.5, delta=0.1) hist = hist.astype("float32")
hist /= float(outs[0].size)
prob = 0.1 * np.ones((10))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册