diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt index ef5c4bcd73d48578ef0258efcd68c6363549afd9..2eb5cc55cd6ac949fb3a720f414d18300a8ea227 100644 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -1,5 +1,8 @@ +add_subdirectory(accumulation) add_subdirectory(api) add_subdirectory(tests) + cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api) cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api) +cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulation) cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta) diff --git a/paddle/fluid/eager/accumulation/CMakeLists.txt b/paddle/fluid/eager/accumulation/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..bfc7b54bef1567dcc2a0e2e181288aed430dbe73 --- /dev/null +++ b/paddle/fluid/eager/accumulation/CMakeLists.txt @@ -0,0 +1,2 @@ +cc_library(gradient_accumulation SRCS gradient_accumulation.cc DEPS blas pten pten_api var_type_traits layer math_function) +cc_library(accumulation_node SRCS accumulation_node.cc DEPS gradient_accumulation pten pten_api grad_node_info) diff --git a/paddle/fluid/eager/accumulation/accumulation_node.cc b/paddle/fluid/eager/accumulation/accumulation_node.cc new file mode 100644 index 0000000000000000000000000000000000000000..69628d9b40021d092fcb65add7088e7df7fcd18e --- /dev/null +++ b/paddle/fluid/eager/accumulation/accumulation_node.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/accumulation/accumulation_node.h" +#include "paddle/fluid/eager/accumulation/gradient_accumulation.h" +#include "paddle/fluid/eager/eager_tensor.h" + +#include "paddle/pten/api/all.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/include/core.h" + +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" + +#include "glog/logging.h" + +static void CopyOrAddTensor(egr::EagerTensor* tensor, + const egr::EagerTensor& t) { + if (!tensor->defined() || !tensor->initialized()) { + // Simply copy tensor->impl + *tensor = t; + } else { + // Accumulation + egr::TensorAdd(t, tensor); + } +} + +namespace egr { + +void GradNodeAccumulation::RetainGrad( + const std::function& hook) { + retain_grad_hook_ = hook; +} + +std::vector> GradNodeAccumulation::operator()( + const std::vector>& grads) { + PADDLE_ENFORCE(grads.size() == 1, + paddle::platform::errors::Fatal( + "GradNodeAccumulation should take exactly 1 grad tensor" + "However received: %d slot.", + grads.size())); + PADDLE_ENFORCE(grads[0].size() == 1, + paddle::platform::errors::Fatal( + "GradNodeAccumulation should take exactly 1 grad tensor" + "However received: %d in slot %d .", + grads[0].size(), 0)); + // Apply Gradient Hooks + if (GradientHooksRegistered()) { + std::vector> hooked_grads = + ApplyGradientHooks(grads); + // TODO(jiabin): It's little weird + CopyOrAddTensor(&accumulated_grad, hooked_grads[0][0]); + } else { + CopyOrAddTensor(&accumulated_grad, grads[0][0]); + } + + if (retain_grad_hook_ != nullptr) { + retain_grad_hook_(accumulated_grad); + } + + // Apply Reduce Hooks + if (ReduceHooksRegistered()) { + ApplyReduceHooks(); + } + + return {{accumulated_grad}}; +} + +} // namespace egr diff --git a/paddle/fluid/eager/accumulation/accumulation_node.h b/paddle/fluid/eager/accumulation/accumulation_node.h new file mode 100644 index 0000000000000000000000000000000000000000..2582cd3c9df8ec507878c957500e507399b8c953 --- /dev/null +++ b/paddle/fluid/eager/accumulation/accumulation_node.h @@ -0,0 +1,41 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/eager/grad_node_info.h" + +namespace egr { + +class GradNodeAccumulation : public GradNodeBase { + public: + // Constructor: configure fwd input tensors to grad node + GradNodeAccumulation() : GradNodeBase(1, 1) { SetDefaultGradInOutMeta(); } + + ~GradNodeAccumulation() override = default; + + // Functor: perform backward computations + virtual std::vector> operator()( + const std::vector>& grads) override; + + void RetainGrad( + const std::function& hook); + + private: + egr::EagerTensor accumulated_grad; + + std::function retain_grad_hook_; +}; + +} // namespace egr diff --git a/paddle/fluid/eager/accumulation/gradient_accumulation.cc b/paddle/fluid/eager/accumulation/gradient_accumulation.cc new file mode 100644 index 0000000000000000000000000000000000000000..13fe394b5de4e5344603fac85b4718304b66e1ec --- /dev/null +++ b/paddle/fluid/eager/accumulation/gradient_accumulation.cc @@ -0,0 +1,304 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/accumulation/gradient_accumulation.h" +#include +#include +#include +#include "paddle/fluid/eager/eager_tensor.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/math_function_impl.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/complex.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/pten/api/all.h" +#include "paddle/pten/core/convert_utils.h" +#include "paddle/pten/include/core.h" +#include "unsupported/Eigen/CXX11/Tensor" +#ifdef PADDLE_WITH_XPU +#include "xpu/refactor/math.h" +#endif +#ifdef PADDLE_WITH_ASCEND_CL +#include "paddle/fluid/operators/npu_op_runner.h" +#endif + +namespace egr { +template +class TensorAddFunctor : public boost::static_visitor<> { + public: + TensorAddFunctor(int64_t numel, const T* x, T* y) + : numel_(numel), x_(x), y_(y) {} + + void operator()(const paddle::platform::CPUPlace& place) { + paddle::platform::CPUDeviceContext* ctx = + dynamic_cast( + paddle::platform::DeviceContextPool::Instance().Get(place)); + auto blas = + paddle::operators::math::GetBlas( + *ctx); + blas.AXPY(numel_, 1., x_, y_); + } + +// TODO(jiabin): Support xpu here from gradient_accumulator.cc + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + void operator()(const paddle::platform::CUDAPlace& place) { + paddle::platform::CUDADeviceContext* ctx = + dynamic_cast( + paddle::platform::DeviceContextPool::Instance().Get(place)); + auto blas = + paddle::operators::math::GetBlas(*ctx); + blas.AXPY(numel_, 1., x_, y_); + } +#else + void operator()(const paddle::platform::CUDAPlace& place) { + PADDLE_THROW(paddle::platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } +#endif + + // TODO(jiabin): Support Npu here from gradient_accumulator.cc + // there is NO blas in CUDAPinnedPlace + void operator()(const paddle::platform::CUDAPinnedPlace& place) { + PADDLE_THROW(paddle::platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } + +#ifdef PADDLE_WITH_ASCEND_CL + void operator()(const paddle::platform::NPUPlace& place) { + PADDLE_THROW(paddle::platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } +#else + void operator()(const paddle::platform::NPUPlace& place) { + PADDLE_THROW(paddle::platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } +#endif + +#ifdef PADDLE_WITH_XPU + void operator()(const paddle::platform::XPUPlace& place) { + paddle::platform::XPUDeviceContext* ctx = + dynamic_cast( + paddle::platform::DeviceContextPool::Instance().Get(place)); + xpu::add(ctx->x_context(), x_, y_, y_, static_cast(numel_)); + } +#else + void operator()(const paddle::platform::XPUPlace& place) { + PADDLE_THROW(paddle::platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } +#endif + + void operator()(const paddle::platform::NPUPinnedPlace& place) { + PADDLE_THROW(paddle::platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } + + private: + int64_t numel_; + const T* x_; + T* y_; +}; + +template +void TensorAddImpl(const std::shared_ptr& src, + pten::DenseTensor* dst, + const paddle::platform::Place& place) { + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + paddle::platform::DeviceContext* ctx = pool.Get(place); + auto dev_ctx = dynamic_cast(ctx); + paddle::operators::math::ElementwiseAddTo func; + func(dev_ctx, *(src.get()), dst); +} + +template +void TensorAddImpl(const paddle::framework::Tensor& src, + paddle::framework::Tensor* dst, + const paddle::platform::Place& place) { + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + paddle::platform::DeviceContext* ctx = pool.Get(place); + auto dev_ctx = dynamic_cast(ctx); + paddle::operators::math::ElementwiseAddTo func; + func(dev_ctx, src, dst); +} + +void TensorAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) { + // TODO(jiabin): Support other tensor type later + std::shared_ptr dst_tensor = + std::dynamic_pointer_cast(dst->impl()); + std::shared_ptr src_tensor = + std::dynamic_pointer_cast(src.impl()); + + auto numel = src_tensor->numel(); + + if (numel == 0) { + return; + } + + PADDLE_ENFORCE_EQ( + dst_tensor->numel(), numel, + paddle::platform::errors::PreconditionNotMet( + "The number of elements of source tensor and destination tensor " + "should be equal, but got the number of elements of source tensor is " + "%zu and the number of elements of destination tensor is %zu.", + numel, dst_tensor->numel())); + + auto data_type = pten::TransToProtoVarType(src_tensor->dtype()); + auto place = src_tensor->place(); + + PADDLE_ENFORCE_EQ(pten::TransToProtoVarType(dst_tensor->dtype()), data_type, + paddle::platform::errors::PreconditionNotMet( + "The data type of source tensor and destination tensor " + "should be equal, Otherwise, the calculation results " + "will be incorrect.")); + +#define PADDLE_TENSOR_ADD(cpp_type) \ + if (data_type == paddle::framework::DataTypeTrait::DataType()) { \ + TensorAddFunctor func(numel, src_tensor->data(), \ + dst_tensor->mutable_data()); \ + boost::apply_visitor(func, place); \ + return; \ + } + + // TODO(jiabin): Support NPU here + PADDLE_TENSOR_ADD(float); + // NOTE(phlrain): xpu only support float + PADDLE_TENSOR_ADD(double); + // NOTE(chenweihang): only support complex grad tensor accumulated, + // support selected rows if needed in the future + PADDLE_TENSOR_ADD(paddle::platform::complex); + PADDLE_TENSOR_ADD(paddle::platform::complex); + +#undef PADDLE_TENSOR_ADD + + if (data_type == paddle::framework::proto::VarType::FP16) { + if (paddle::platform::is_gpu_place(place)) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + return TensorAddImpl(src_tensor, + dst_tensor.get(), place); +#else + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Gradient accumulation of data type (%s) on place (%s) is not " + "supported in imperative mode", + paddle::framework::DataTypeToString(data_type), place)); +#endif + } else if (paddle::platform::is_cpu_place(place)) { + return TensorAddImpl(src_tensor, + dst_tensor.get(), place); + } + } + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Gradient accumulation of data type (%s) on place (%s) is not " + "supported in imperative mode", + paddle::framework::DataTypeToString(data_type), place)); +} + +void VariableAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) { + // TODO(jiabin): Support other tensor type later + auto* dst_tensor = + dst->MutableVar()->GetMutable(); + auto& src_tensor = src.Var().Get(); + + auto numel = src_tensor.numel(); + + // FIXME(minqiyang): loss_grad op will pass a zero grad of label + // ugly fix for it + if (numel == 0) { + return; + } + + PADDLE_ENFORCE_EQ( + dst_tensor->numel(), numel, + paddle::platform::errors::PreconditionNotMet( + "The number of elements of source tensor and destination tensor " + "should be equal, but got the number of elements of source tensor is " + "%zu and the number of elements of destination tensor is %zu.", + numel, dst_tensor->numel())); + + auto data_type = src_tensor.type(); + auto place = src_tensor.place(); + + PADDLE_ENFORCE_EQ(dst_tensor->type(), data_type, + paddle::platform::errors::PreconditionNotMet( + "The data type of source tensor and destination tensor " + "should be equal, Otherwise, the calculation results " + "will be incorrect.")); + +#define PADDLE_TENSOR_ADD(cpp_type) \ + if (data_type == paddle::framework::DataTypeTrait::DataType()) { \ + TensorAddFunctor func( \ + numel, src_tensor.data(), \ + dst_tensor->mutable_data(place)); \ + boost::apply_visitor(func, place); \ + return; \ + } + + // TODO(jiabin): Support NPU here + PADDLE_TENSOR_ADD(float); + // NOTE(phlrain): xpu only support float + PADDLE_TENSOR_ADD(double); + // NOTE(chenweihang): only support complex grad tensor accumulated, + // support selected rows if needed in the future + PADDLE_TENSOR_ADD(paddle::platform::complex); + PADDLE_TENSOR_ADD(paddle::platform::complex); + +#undef PADDLE_TENSOR_ADD + + if (data_type == paddle::framework::proto::VarType::FP16) { + if (paddle::platform::is_gpu_place(place)) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + return TensorAddImpl(src_tensor, dst_tensor, + place); +#else + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Gradient accumulation of data type (%s) on place (%s) is not " + "supported in imperative mode", + paddle::framework::DataTypeToString(data_type), place)); +#endif + } else if (paddle::platform::is_cpu_place(place)) { + return TensorAddImpl(src_tensor, dst_tensor, + place); + } + } + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Gradient accumulation of data type (%s) on place (%s) is not " + "supported in imperative mode", + paddle::framework::DataTypeToString(data_type), place)); +} + +} // namespace egr diff --git a/paddle/fluid/eager/accumulation/gradient_accumulation.h b/paddle/fluid/eager/accumulation/gradient_accumulation.h new file mode 100644 index 0000000000000000000000000000000000000000..725410dac729e6421cbcd6a4d104f7d271404cff --- /dev/null +++ b/paddle/fluid/eager/accumulation/gradient_accumulation.h @@ -0,0 +1,23 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/eager/eager_tensor.h" +#include "paddle/pten/api/all.h" +namespace egr { +// Accumulation API +void TensorAdd(const egr::EagerTensor& src, egr::EagerTensor* dst); +void VariableAdd(const egr::EagerTensor& src, egr::EagerTensor* dst); + +} // namespace egr diff --git a/paddle/fluid/eager/eager_tensor.h b/paddle/fluid/eager/eager_tensor.h index 753040a2623f9ba2b79c3587af179fb1a2787146..7ade0a9848dc41be9d7b1da49b3808c5a4abeb0d 100644 --- a/paddle/fluid/eager/eager_tensor.h +++ b/paddle/fluid/eager/eager_tensor.h @@ -248,6 +248,14 @@ class EagerTensor final { void ResetVar(const paddle::framework::Variable& src) { var_ = src; } + const std::shared_ptr& Tensor() const { + return tensor_; + } + + void set_tensor(const std::shared_ptr& tensor) { + tensor_ = tensor; + } + private: template void SetImplWithLegacyTensor() { diff --git a/paddle/fluid/eager/grad_tensor_holder.cc b/paddle/fluid/eager/grad_tensor_holder.cc new file mode 100644 index 0000000000000000000000000000000000000000..c0344e20fb9bbdedaf37aa91809011d2fd0b6276 --- /dev/null +++ b/paddle/fluid/eager/grad_tensor_holder.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/grad_tensor_holder.h" +#include "paddle/fluid/eager/accumulation/gradient_accumulation.h" + +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace egr { + +static void FillUnderlyingVariableWithValue( + double value, const paddle::framework::DDim& ddim, + const paddle::platform::Place& place, + const paddle::framework::proto::VarType::Type& dtype, + egr::EagerTensor* target) { + auto* dst_tensor = + target->MutableVar()->GetMutable(); + auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place); + dst_tensor->Resize(ddim); + // TOOD(jiabin): Ugly fix here we have fwd_data_type_ and data_type, since in + // grad mission + // we can't get data_type_ directly. We need to check if we can only use + // default data_type for now. + dst_tensor->mutable_data(place, dtype); + paddle::operators::math::set_constant(*dev_ctx, dst_tensor, value); +} + +void GradTensorHolder::add(size_t slot_id, size_t rank, + const egr::EagerTensor& t, bool fill_one) { + // TODO(jiabin): We need to deal with empty input_buffer with slot size not + // empty; + PADDLE_ENFORCE(slot_id < buffer_.size(), + paddle::platform::errors::Fatal( + "Invalid slot_id for GradTensorHolder::add() " + "which exceeds size of buffer")); + VLOG(6) << "Add Tensor for buffer_ slot: " << slot_id + << ", size: " << buffer_[slot_id].size(); + if (buffer_[slot_id].empty()) { + VLOG(6) << "Pass add Tensor for buffer_ slot: " << slot_id + << " since its buffer_ is empty "; + return; + } + PADDLE_ENFORCE( + rank < buffer_[slot_id].size(), + paddle::platform::errors::Fatal( + "Invalid rank for GradTensorHolder::add() which exceeds size " + "of buffer slot %d, got slot size is: %d rank is: %d", + slot_id, buffer_[slot_id].size(), rank)); + egr::EagerTensor& buffer_tensor = buffer_[slot_id][rank]; + if (!fill_one) { + // TODO(jiabin): Code bellow is ugly to divide which inner var we used, + // remove framework::Variable + // related code later. + // This if statement is trying to test neither pten::Tensor nor + // framework::Variable is initialized. + if ((!buffer_tensor.defined() || !buffer_tensor.initialized()) && + (!buffer_tensor.Var().IsInitialized())) { + // Simply copy tensor->impl + buffer_tensor = t; + } else { + // Accumulation + if (t.initialized() && buffer_tensor.initialized()) { + TensorAdd(t, &buffer_tensor); + } else if (t.Var().IsInitialized() && + buffer_tensor.Var().IsInitialized()) { + VariableAdd(t, &buffer_tensor); + } else if (t.Var().IsInitialized() && buffer_tensor.initialized()) { + // TODO(jiabin): This can be merge to upper if case. + buffer_tensor.SyncToVar(); + VariableAdd(t, &buffer_tensor); + } else if (t.initialized() && buffer_tensor.Var().IsInitialized()) { + buffer_tensor.SyncToTensor(); + TensorAdd(t, &buffer_tensor); + } else { + // Should not happend case + // 1. both not init + } + } + } else { + // Create new tensor->impl and fill it with 1.0 + if (t.defined()) { + // Fill 1.0 + paddle::experimental::Tensor tensor = + paddle::experimental::ones_like(*t.Tensor().get()); + buffer_tensor.set_tensor( + std::make_shared(tensor)); + + } else { + // TODO(jiabin): Only Support LodTensorForNow + auto type = paddle::framework::ToVarType(t.Var().Type()); + switch (type) { + case paddle::framework::proto::VarType::LOD_TENSOR: { + auto t_ftensor = t.Var().Get(); + FillUnderlyingVariableWithValue(1.0, t_ftensor.dims(), + t_ftensor.place(), t_ftensor.type(), + &buffer_tensor); + break; + } + default: { + PADDLE_THROW(paddle::platform::errors::NotFound( + "Cannot found var type: %s in Fill Constant API", + paddle::framework::ToTypeName(type))); + } + } + } + } +} + +} // namespace egr diff --git a/paddle/fluid/eager/grad_tensor_holder.h b/paddle/fluid/eager/grad_tensor_holder.h new file mode 100644 index 0000000000000000000000000000000000000000..5072447fa9343ea5c5394672912b665bed2f3ad1 --- /dev/null +++ b/paddle/fluid/eager/grad_tensor_holder.h @@ -0,0 +1,61 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/eager/grad_node_info.h" + +namespace egr { + +/** + * Input Buffer is designed for backward grad accumulate. + * Since we will have one output used by multi preceding ops in forward pass, + * we will meet a problem that we need to accumulate multiple grads into one. + * + * GradTensorHolder should have as same format as forward output **/ +class GradTensorHolder { + public: + explicit GradTensorHolder(const std::vector& meta) { + VLOG(7) << "Init GradTensorHolder with meta size: " << meta.size(); + buffer_.resize(meta.size()); + for (size_t i = 0; i < buffer_.size(); i++) { + VLOG(7) << "Init GradTensorHolder with meta rank: " << meta[i].Size(); + buffer_[i].resize(meta[i].Size()); + } + } + + GradTensorHolder(const GradTensorHolder& other) = default; + + explicit GradTensorHolder(std::vector>&& inputs) + : buffer_(std::move(inputs)) {} + + GradTensorHolder& operator=(const GradTensorHolder& other) = default; + + // Create new tensor and copy tensor->impl + void add(size_t slot_id, size_t rank, const egr::EagerTensor& t, + bool fill_one = false); + + const std::vector& operator[](const size_t& pos) { + return buffer_[pos]; + } + + const std::vector>& Buffers() { + return buffer_; + } + + private: + std::vector> buffer_; +}; + +} // namespace egr diff --git a/paddle/fluid/eager/tests/CMakeLists.txt b/paddle/fluid/eager/tests/CMakeLists.txt index 35b9934141b4dfb3b4ad787088d715eb31653a74..fb01fcf91a94db342f0bfa9cbfb189d7307ceeab 100644 --- a/paddle/fluid/eager/tests/CMakeLists.txt +++ b/paddle/fluid/eager/tests/CMakeLists.txt @@ -1,3 +1,3 @@ -set(eager_deps pten pten_api pten_tensor utils global_utils autograd_meta grad_node_info) +set(eager_deps pten pten_api utils global_utils pten_tensor autograd_meta grad_node_info grad_tensor_holder gradient_accumulation accumulation_node) add_subdirectory(data_structure_tests) add_subdirectory(task_tests) diff --git a/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt b/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt index 2989330efa8aac1f3a6b7cec0e3600cd0f999318..2b06687db1af70b655cde45f98c58307647b7412 100644 --- a/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt +++ b/paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt @@ -1,4 +1,6 @@ cc_test(test_egr_ds_eager_tensor SRCS eager_tensor_test.cc DEPS ${eager_deps} ) -cc_test(test_egr_ds_auotgrad_meta SRCS autograd_meta_test.cc DEPS ${eager_deps} grad_node_info) -cc_test(test_egr_ds_grad_node_info SRCS grad_node_info_test.cc DEPS ${eager_deps} grad_node_info) -cc_test(test_egr_ds_tensor_wrapper SRCS tensor_wrapper_test.cc DEPS ${eager_deps} grad_node_info utils) +cc_test(test_egr_ds_auotgrad_meta SRCS autograd_meta_test.cc DEPS ${eager_deps}) +cc_test(test_egr_ds_grad_node_info SRCS grad_node_info_test.cc DEPS ${eager_deps}) +cc_test(test_egr_ds_tensor_wrapper SRCS tensor_wrapper_test.cc DEPS ${eager_deps}) +cc_test(test_egr_ds_grad_tensor_holder SRCS grad_tensor_holder_test.cc DEPS ${eager_deps}) +cc_test(test_egr_ds_accumulation_node SRCS accumulation_node_test.cc DEPS ${eager_deps}) diff --git a/paddle/fluid/eager/tests/data_structure_tests/accumulation_node_test.cc b/paddle/fluid/eager/tests/data_structure_tests/accumulation_node_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..20601d0c5811342b93755a8204d24d8b9268384a --- /dev/null +++ b/paddle/fluid/eager/tests/data_structure_tests/accumulation_node_test.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/eager/accumulation/accumulation_node.h" +#include "paddle/fluid/eager/eager_tensor.h" +#include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/grad_tensor_holder.h" +#include "paddle/pten/api/lib/utils/allocator.h" + +#include "paddle/pten/core/kernel_registry.h" + +// TODO(jiabin): remove nolint here!!! +using namespace egr; // NOLINT + +TEST(AccumulationNode, EagerTensor) { + // Construct Eager Tensor + pten::DenseTensorMeta meta = pten::DenseTensorMeta( + pten::DataType::FLOAT16, paddle::framework::make_ddim({1, 1})); + std::shared_ptr dt0 = std::make_shared( + std::make_shared( + paddle::platform::CPUPlace()), + meta); + dt0->mutable_data()[0] = 10.0; + EagerTensor et0 = EagerTensor(dt0); + + std::shared_ptr dt1 = std::make_shared( + std::make_shared( + paddle::platform::CPUPlace()), + meta); + dt1->mutable_data()[0] = 20.0; + EagerTensor et1 = EagerTensor(dt1); + + std::shared_ptr grad_dt = + std::make_shared( + std::make_shared( + paddle::platform::CPUPlace()), + meta); + EagerTensor grad_et = EagerTensor(grad_dt); + + // AccumulationNode + GradNodeAccumulation node = GradNodeAccumulation(); + + // Hook + std::function hook = + [&grad_et](const egr::EagerTensor& t) { + if (t.defined()) { + grad_et.set_impl(t.impl()); + return grad_et; + } else { + grad_et.MutableVar() + ->GetMutable() + ->ShareDataWith(t.Var().Get()); + return grad_et; + } + }; + node.RetainGrad(hook); + + // operator() + EagerTensor ret_et0 = node({{et0}})[0][0]; + auto* ret_et0_ptr = + std::dynamic_pointer_cast(ret_et0.impl()) + ->data(); + CHECK_EQ(ret_et0_ptr[0], paddle::platform::float16(10.0f)); + + EagerTensor ret_et1 = node({{et1}})[0][0]; + auto* ret_et1_ptr = + std::dynamic_pointer_cast(ret_et1.impl()) + ->data(); + CHECK_EQ(ret_et1_ptr[0], paddle::platform::float16(30.0f)); + + // Retain Grad + auto* ret_grad_et_ptr = + std::dynamic_pointer_cast(grad_et.impl()) + ->data(); + CHECK_EQ(ret_grad_et_ptr[0], paddle::platform::float16(30.0f)); +} diff --git a/paddle/fluid/eager/tests/data_structure_tests/grad_tensor_holder_test.cc b/paddle/fluid/eager/tests/data_structure_tests/grad_tensor_holder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3581ef59cd5bee409ef27035954de6319d65a680 --- /dev/null +++ b/paddle/fluid/eager/tests/data_structure_tests/grad_tensor_holder_test.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/eager/eager_tensor.h" +#include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/grad_tensor_holder.h" +#include "paddle/pten/api/lib/utils/allocator.h" + +#include "paddle/pten/core/kernel_registry.h" + +// TODO(jiabin): remove nolint here!!! +using namespace egr; // NOLINT + +TEST(GradTensorHolder, Constructor) { + GradSlotMeta slot_meta; + slot_meta.Init(1); + GradTensorHolder grad_tensor_holder = GradTensorHolder({slot_meta}); + GradTensorHolder grad_tensor_holder2 = GradTensorHolder(grad_tensor_holder); + + // Construct Eager Tensor + pten::DenseTensorMeta meta = pten::DenseTensorMeta( + pten::DataType::FLOAT32, paddle::framework::make_ddim({2, 2})); + std::shared_ptr dt = std::make_shared( + std::make_shared( + paddle::platform::CPUPlace()), + meta); + EagerTensor et = EagerTensor(dt); + + std::vector> inputs; + inputs.push_back({et}); + + GradTensorHolder grad_tensor_holder4 = GradTensorHolder(std::move(inputs)); +} + +TEST(GradTensorHolder, Interfaces) { + // Construct Eager Tensor + pten::DenseTensorMeta meta = pten::DenseTensorMeta( + pten::DataType::FLOAT32, paddle::framework::make_ddim({1, 1})); + std::shared_ptr dt0 = std::make_shared( + std::make_shared( + paddle::platform::CPUPlace()), + meta); + dt0->mutable_data()[0] = 10.0; + EagerTensor et0 = EagerTensor(dt0); + + std::shared_ptr dt1 = std::make_shared( + std::make_shared( + paddle::platform::CPUPlace()), + meta); + dt1->mutable_data()[0] = 20.0; + EagerTensor et1 = EagerTensor(dt1); + + // Constructor empty GradTensorHolder + GradSlotMeta slot_meta; + slot_meta.Init(1); + GradTensorHolder grad_tensor_holder = + GradTensorHolder({slot_meta, slot_meta}); + + // add(): + // fill one + grad_tensor_holder.add(0, 0, et0, true); + + // accumulation + grad_tensor_holder.add(1, 0, et0, false); + grad_tensor_holder.add(1, 0, et1, false); + + // Buffers() + const auto& buffers = grad_tensor_holder.Buffers(); + CHECK_EQ(static_cast(buffers.size()), 2); + CHECK_EQ(static_cast(buffers[0].size()), 1); + CHECK_EQ(static_cast(buffers[1].size()), 1); + + // operator[] + const auto& holder_et0 = grad_tensor_holder[0][0]; + const auto& holder_et1 = grad_tensor_holder[1][0]; + + auto* holder_et0_ptr = + std::dynamic_pointer_cast(holder_et0.impl()) + ->data(); + auto* holder_et1_ptr = + std::dynamic_pointer_cast(holder_et1.impl()) + ->data(); + + CHECK_EQ(holder_et0_ptr[0], 1.0f); + CHECK_EQ(holder_et1_ptr[0], 30.0f); +} diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index c6005bebe18554f9928cd36b1f75a0f2553a89e6..a2f619d84a21e2bca5821734bb8fc03541ea4f8c 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -61,7 +61,7 @@ math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) -math_library(math_function DEPS blas) +math_library(math_function DEPS blas dense_tensor tensor) math_library(maxouting) math_library(pooling) diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 9bd833919626d102ef533183693de4cbfb16be6a..cd919b18b83c80d8caa4bbf4e3334e7432dd18f9 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -29,6 +29,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function_impl.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/pten/kernels/functions/eigen/common.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { @@ -268,6 +269,13 @@ struct ElementwiseAddTo { auto& place = *(ctx->eigen_device()); out.device(place) = out + in; } + void operator()(platform::CPUDeviceContext* ctx, const pten::DenseTensor& src, + pten::DenseTensor* dst) { + auto in = pten::EigenVector::Flatten(src); + auto out = pten::EigenVector::Flatten(*dst); + auto& place = *(ctx->eigen_device()); + out.device(place) = out + in; + } }; template struct ElementwiseAddTo; diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index cfdfa456e39eac6f09a7f0d261d228f7f98b2da8..144db9f5a2bb4be5d9383b3893a95f099c6ae00a 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function_impl.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/pten/kernels/functions/eigen/common.h" namespace paddle { namespace operators { @@ -283,10 +284,18 @@ struct ElementwiseAddTo { auto& place = *(ctx->eigen_device()); out.device(place) = out + in; } + void operator()(platform::CUDADeviceContext* ctx, + const pten::DenseTensor& src, pten::DenseTensor* dst) { + auto in = pten::EigenVector::Flatten(src); + auto out = pten::EigenVector::Flatten(*dst); + auto& place = *(ctx->eigen_device()); + out.device(place) = out + in; + } }; template struct ElementwiseAddTo; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/math_function.h b/paddle/fluid/operators/math/math_function.h index ea313cb616916cb8e60f4e44cd7edc15b51063cc..4c0eb592e8c17b603c559867bf77bde766001bac 100644 --- a/paddle/fluid/operators/math/math_function.h +++ b/paddle/fluid/operators/math/math_function.h @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/pten/core/dense_tensor.h" namespace paddle { namespace operators { @@ -65,6 +66,8 @@ struct ElementwiseAddTo { // dst = dst + src void operator()(DeviceContext* ctx, const framework::Tensor& src, framework::Tensor* dst); + void operator()(DeviceContext* ctx, const pten::DenseTensor& src, + pten::DenseTensor* dst); }; template