未验证 提交 bc9f9f43 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Added GradTensorHolder to Eager Dygraph (#37458)

* Added GradTensorHolder to Eager Dygraph

* Added accumulation codes to Eager Dygraph

* Fix windows-ci issue

* Fix NPU-CI issue

* Fixed CI-Coverage issue
上级 3f815e76
add_subdirectory(accumulation)
add_subdirectory(api)
add_subdirectory(tests)
cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api)
cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulation)
cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta)
cc_library(gradient_accumulation SRCS gradient_accumulation.cc DEPS blas pten pten_api var_type_traits layer math_function)
cc_library(accumulation_node SRCS accumulation_node.cc DEPS gradient_accumulation pten pten_api grad_node_info)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/pten/api/all.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/include/core.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "glog/logging.h"
static void CopyOrAddTensor(egr::EagerTensor* tensor,
const egr::EagerTensor& t) {
if (!tensor->defined() || !tensor->initialized()) {
// Simply copy tensor->impl
*tensor = t;
} else {
// Accumulation
egr::TensorAdd(t, tensor);
}
}
namespace egr {
void GradNodeAccumulation::RetainGrad(
const std::function<egr::EagerTensor(const egr::EagerTensor&)>& hook) {
retain_grad_hook_ = hook;
}
std::vector<std::vector<egr::EagerTensor>> GradNodeAccumulation::operator()(
const std::vector<std::vector<egr::EagerTensor>>& grads) {
PADDLE_ENFORCE(grads.size() == 1,
paddle::platform::errors::Fatal(
"GradNodeAccumulation should take exactly 1 grad tensor"
"However received: %d slot.",
grads.size()));
PADDLE_ENFORCE(grads[0].size() == 1,
paddle::platform::errors::Fatal(
"GradNodeAccumulation should take exactly 1 grad tensor"
"However received: %d in slot %d .",
grads[0].size(), 0));
// Apply Gradient Hooks
if (GradientHooksRegistered()) {
std::vector<std::vector<egr::EagerTensor>> hooked_grads =
ApplyGradientHooks(grads);
// TODO(jiabin): It's little weird
CopyOrAddTensor(&accumulated_grad, hooked_grads[0][0]);
} else {
CopyOrAddTensor(&accumulated_grad, grads[0][0]);
}
if (retain_grad_hook_ != nullptr) {
retain_grad_hook_(accumulated_grad);
}
// Apply Reduce Hooks
if (ReduceHooksRegistered()) {
ApplyReduceHooks();
}
return {{accumulated_grad}};
}
} // namespace egr
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/eager/grad_node_info.h"
namespace egr {
class GradNodeAccumulation : public GradNodeBase {
public:
// Constructor: configure fwd input tensors to grad node
GradNodeAccumulation() : GradNodeBase(1, 1) { SetDefaultGradInOutMeta(); }
~GradNodeAccumulation() override = default;
// Functor: perform backward computations
virtual std::vector<std::vector<egr::EagerTensor>> operator()(
const std::vector<std::vector<egr::EagerTensor>>& grads) override;
void RetainGrad(
const std::function<egr::EagerTensor(const egr::EagerTensor&)>& hook);
private:
egr::EagerTensor accumulated_grad;
std::function<egr::EagerTensor(const egr::EagerTensor&)> retain_grad_hook_;
};
} // namespace egr
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
#include <algorithm>
#include <memory>
#include <utility>
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/api/all.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/include/core.h"
#include "unsupported/Eigen/CXX11/Tensor"
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/operators/npu_op_runner.h"
#endif
namespace egr {
template <typename T>
class TensorAddFunctor : public boost::static_visitor<> {
public:
TensorAddFunctor(int64_t numel, const T* x, T* y)
: numel_(numel), x_(x), y_(y) {}
void operator()(const paddle::platform::CPUPlace& place) {
paddle::platform::CPUDeviceContext* ctx =
dynamic_cast<paddle::platform::CPUDeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place));
auto blas =
paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext, T>(
*ctx);
blas.AXPY(numel_, 1., x_, y_);
}
// TODO(jiabin): Support xpu here from gradient_accumulator.cc
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void operator()(const paddle::platform::CUDAPlace& place) {
paddle::platform::CUDADeviceContext* ctx =
dynamic_cast<paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place));
auto blas =
paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext,
T>(*ctx);
blas.AXPY(numel_, 1., x_, y_);
}
#else
void operator()(const paddle::platform::CUDAPlace& place) {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
// TODO(jiabin): Support Npu here from gradient_accumulator.cc
// there is NO blas in CUDAPinnedPlace
void operator()(const paddle::platform::CUDAPinnedPlace& place) {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#ifdef PADDLE_WITH_ASCEND_CL
void operator()(const paddle::platform::NPUPlace& place) {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#else
void operator()(const paddle::platform::NPUPlace& place) {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
#ifdef PADDLE_WITH_XPU
void operator()(const paddle::platform::XPUPlace& place) {
paddle::platform::XPUDeviceContext* ctx =
dynamic_cast<paddle::platform::XPUDeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place));
xpu::add<T>(ctx->x_context(), x_, y_, y_, static_cast<int>(numel_));
}
#else
void operator()(const paddle::platform::XPUPlace& place) {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
void operator()(const paddle::platform::NPUPinnedPlace& place) {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
private:
int64_t numel_;
const T* x_;
T* y_;
};
template <typename DeviceContext, typename T>
void TensorAddImpl(const std::shared_ptr<pten::DenseTensor>& src,
pten::DenseTensor* dst,
const paddle::platform::Place& place) {
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
paddle::platform::DeviceContext* ctx = pool.Get(place);
auto dev_ctx = dynamic_cast<DeviceContext*>(ctx);
paddle::operators::math::ElementwiseAddTo<DeviceContext, T> func;
func(dev_ctx, *(src.get()), dst);
}
template <typename DeviceContext, typename T>
void TensorAddImpl(const paddle::framework::Tensor& src,
paddle::framework::Tensor* dst,
const paddle::platform::Place& place) {
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
paddle::platform::DeviceContext* ctx = pool.Get(place);
auto dev_ctx = dynamic_cast<DeviceContext*>(ctx);
paddle::operators::math::ElementwiseAddTo<DeviceContext, T> func;
func(dev_ctx, src, dst);
}
void TensorAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) {
// TODO(jiabin): Support other tensor type later
std::shared_ptr<pten::DenseTensor> dst_tensor =
std::dynamic_pointer_cast<pten::DenseTensor>(dst->impl());
std::shared_ptr<pten::DenseTensor> src_tensor =
std::dynamic_pointer_cast<pten::DenseTensor>(src.impl());
auto numel = src_tensor->numel();
if (numel == 0) {
return;
}
PADDLE_ENFORCE_EQ(
dst_tensor->numel(), numel,
paddle::platform::errors::PreconditionNotMet(
"The number of elements of source tensor and destination tensor "
"should be equal, but got the number of elements of source tensor is "
"%zu and the number of elements of destination tensor is %zu.",
numel, dst_tensor->numel()));
auto data_type = pten::TransToProtoVarType(src_tensor->dtype());
auto place = src_tensor->place();
PADDLE_ENFORCE_EQ(pten::TransToProtoVarType(dst_tensor->dtype()), data_type,
paddle::platform::errors::PreconditionNotMet(
"The data type of source tensor and destination tensor "
"should be equal, Otherwise, the calculation results "
"will be incorrect."));
#define PADDLE_TENSOR_ADD(cpp_type) \
if (data_type == paddle::framework::DataTypeTrait<cpp_type>::DataType()) { \
TensorAddFunctor<cpp_type> func(numel, src_tensor->data<cpp_type>(), \
dst_tensor->mutable_data<cpp_type>()); \
boost::apply_visitor(func, place); \
return; \
}
// TODO(jiabin): Support NPU here
PADDLE_TENSOR_ADD(float);
// NOTE(phlrain): xpu only support float
PADDLE_TENSOR_ADD(double);
// NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future
PADDLE_TENSOR_ADD(paddle::platform::complex<float>);
PADDLE_TENSOR_ADD(paddle::platform::complex<double>);
#undef PADDLE_TENSOR_ADD
if (data_type == paddle::framework::proto::VarType::FP16) {
if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return TensorAddImpl<paddle::platform::CUDADeviceContext,
paddle::platform::float16>(src_tensor,
dst_tensor.get(), place);
#else
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
paddle::framework::DataTypeToString(data_type), place));
#endif
} else if (paddle::platform::is_cpu_place(place)) {
return TensorAddImpl<paddle::platform::CPUDeviceContext,
paddle::platform::float16>(src_tensor,
dst_tensor.get(), place);
}
}
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
paddle::framework::DataTypeToString(data_type), place));
}
void VariableAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) {
// TODO(jiabin): Support other tensor type later
auto* dst_tensor =
dst->MutableVar()->GetMutable<paddle::framework::LoDTensor>();
auto& src_tensor = src.Var().Get<paddle::framework::LoDTensor>();
auto numel = src_tensor.numel();
// FIXME(minqiyang): loss_grad op will pass a zero grad of label
// ugly fix for it
if (numel == 0) {
return;
}
PADDLE_ENFORCE_EQ(
dst_tensor->numel(), numel,
paddle::platform::errors::PreconditionNotMet(
"The number of elements of source tensor and destination tensor "
"should be equal, but got the number of elements of source tensor is "
"%zu and the number of elements of destination tensor is %zu.",
numel, dst_tensor->numel()));
auto data_type = src_tensor.type();
auto place = src_tensor.place();
PADDLE_ENFORCE_EQ(dst_tensor->type(), data_type,
paddle::platform::errors::PreconditionNotMet(
"The data type of source tensor and destination tensor "
"should be equal, Otherwise, the calculation results "
"will be incorrect."));
#define PADDLE_TENSOR_ADD(cpp_type) \
if (data_type == paddle::framework::DataTypeTrait<cpp_type>::DataType()) { \
TensorAddFunctor<cpp_type> func( \
numel, src_tensor.data<cpp_type>(), \
dst_tensor->mutable_data<cpp_type>(place)); \
boost::apply_visitor(func, place); \
return; \
}
// TODO(jiabin): Support NPU here
PADDLE_TENSOR_ADD(float);
// NOTE(phlrain): xpu only support float
PADDLE_TENSOR_ADD(double);
// NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future
PADDLE_TENSOR_ADD(paddle::platform::complex<float>);
PADDLE_TENSOR_ADD(paddle::platform::complex<double>);
#undef PADDLE_TENSOR_ADD
if (data_type == paddle::framework::proto::VarType::FP16) {
if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return TensorAddImpl<paddle::platform::CUDADeviceContext,
paddle::platform::float16>(src_tensor, dst_tensor,
place);
#else
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
paddle::framework::DataTypeToString(data_type), place));
#endif
} else if (paddle::platform::is_cpu_place(place)) {
return TensorAddImpl<paddle::platform::CPUDeviceContext,
paddle::platform::float16>(src_tensor, dst_tensor,
place);
}
}
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
paddle::framework::DataTypeToString(data_type), place));
}
} // namespace egr
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/pten/api/all.h"
namespace egr {
// Accumulation API
void TensorAdd(const egr::EagerTensor& src, egr::EagerTensor* dst);
void VariableAdd(const egr::EagerTensor& src, egr::EagerTensor* dst);
} // namespace egr
......@@ -248,6 +248,14 @@ class EagerTensor final {
void ResetVar(const paddle::framework::Variable& src) { var_ = src; }
const std::shared_ptr<paddle::experimental::Tensor>& Tensor() const {
return tensor_;
}
void set_tensor(const std::shared_ptr<paddle::experimental::Tensor>& tensor) {
tensor_ = tensor;
}
private:
template <typename LEGACY_TYPE, typename TYPE>
void SetImplWithLegacyTensor() {
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace egr {
static void FillUnderlyingVariableWithValue(
double value, const paddle::framework::DDim& ddim,
const paddle::platform::Place& place,
const paddle::framework::proto::VarType::Type& dtype,
egr::EagerTensor* target) {
auto* dst_tensor =
target->MutableVar()->GetMutable<paddle::framework::LoDTensor>();
auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
dst_tensor->Resize(ddim);
// TOOD(jiabin): Ugly fix here we have fwd_data_type_ and data_type, since in
// grad mission
// we can't get data_type_ directly. We need to check if we can only use
// default data_type for now.
dst_tensor->mutable_data(place, dtype);
paddle::operators::math::set_constant(*dev_ctx, dst_tensor, value);
}
void GradTensorHolder::add(size_t slot_id, size_t rank,
const egr::EagerTensor& t, bool fill_one) {
// TODO(jiabin): We need to deal with empty input_buffer with slot size not
// empty;
PADDLE_ENFORCE(slot_id < buffer_.size(),
paddle::platform::errors::Fatal(
"Invalid slot_id for GradTensorHolder::add() "
"which exceeds size of buffer"));
VLOG(6) << "Add Tensor for buffer_ slot: " << slot_id
<< ", size: " << buffer_[slot_id].size();
if (buffer_[slot_id].empty()) {
VLOG(6) << "Pass add Tensor for buffer_ slot: " << slot_id
<< " since its buffer_ is empty ";
return;
}
PADDLE_ENFORCE(
rank < buffer_[slot_id].size(),
paddle::platform::errors::Fatal(
"Invalid rank for GradTensorHolder::add() which exceeds size "
"of buffer slot %d, got slot size is: %d rank is: %d",
slot_id, buffer_[slot_id].size(), rank));
egr::EagerTensor& buffer_tensor = buffer_[slot_id][rank];
if (!fill_one) {
// TODO(jiabin): Code bellow is ugly to divide which inner var we used,
// remove framework::Variable
// related code later.
// This if statement is trying to test neither pten::Tensor nor
// framework::Variable is initialized.
if ((!buffer_tensor.defined() || !buffer_tensor.initialized()) &&
(!buffer_tensor.Var().IsInitialized())) {
// Simply copy tensor->impl
buffer_tensor = t;
} else {
// Accumulation
if (t.initialized() && buffer_tensor.initialized()) {
TensorAdd(t, &buffer_tensor);
} else if (t.Var().IsInitialized() &&
buffer_tensor.Var().IsInitialized()) {
VariableAdd(t, &buffer_tensor);
} else if (t.Var().IsInitialized() && buffer_tensor.initialized()) {
// TODO(jiabin): This can be merge to upper if case.
buffer_tensor.SyncToVar();
VariableAdd(t, &buffer_tensor);
} else if (t.initialized() && buffer_tensor.Var().IsInitialized()) {
buffer_tensor.SyncToTensor();
TensorAdd(t, &buffer_tensor);
} else {
// Should not happend case
// 1. both not init
}
}
} else {
// Create new tensor->impl and fill it with 1.0
if (t.defined()) {
// Fill 1.0
paddle::experimental::Tensor tensor =
paddle::experimental::ones_like(*t.Tensor().get());
buffer_tensor.set_tensor(
std::make_shared<paddle::experimental::Tensor>(tensor));
} else {
// TODO(jiabin): Only Support LodTensorForNow
auto type = paddle::framework::ToVarType(t.Var().Type());
switch (type) {
case paddle::framework::proto::VarType::LOD_TENSOR: {
auto t_ftensor = t.Var().Get<paddle::framework::LoDTensor>();
FillUnderlyingVariableWithValue(1.0, t_ftensor.dims(),
t_ftensor.place(), t_ftensor.type(),
&buffer_tensor);
break;
}
default: {
PADDLE_THROW(paddle::platform::errors::NotFound(
"Cannot found var type: %s in Fill Constant API",
paddle::framework::ToTypeName(type)));
}
}
}
}
}
} // namespace egr
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/eager/grad_node_info.h"
namespace egr {
/**
* Input Buffer is designed for backward grad accumulate.
* Since we will have one output used by multi preceding ops in forward pass,
* we will meet a problem that we need to accumulate multiple grads into one.
*
* GradTensorHolder should have as same format as forward output **/
class GradTensorHolder {
public:
explicit GradTensorHolder(const std::vector<GradSlotMeta>& meta) {
VLOG(7) << "Init GradTensorHolder with meta size: " << meta.size();
buffer_.resize(meta.size());
for (size_t i = 0; i < buffer_.size(); i++) {
VLOG(7) << "Init GradTensorHolder with meta rank: " << meta[i].Size();
buffer_[i].resize(meta[i].Size());
}
}
GradTensorHolder(const GradTensorHolder& other) = default;
explicit GradTensorHolder(std::vector<std::vector<egr::EagerTensor>>&& inputs)
: buffer_(std::move(inputs)) {}
GradTensorHolder& operator=(const GradTensorHolder& other) = default;
// Create new tensor and copy tensor->impl
void add(size_t slot_id, size_t rank, const egr::EagerTensor& t,
bool fill_one = false);
const std::vector<egr::EagerTensor>& operator[](const size_t& pos) {
return buffer_[pos];
}
const std::vector<std::vector<egr::EagerTensor>>& Buffers() {
return buffer_;
}
private:
std::vector<std::vector<egr::EagerTensor>> buffer_;
};
} // namespace egr
set(eager_deps pten pten_api pten_tensor utils global_utils autograd_meta grad_node_info)
set(eager_deps pten pten_api utils global_utils pten_tensor autograd_meta grad_node_info grad_tensor_holder gradient_accumulation accumulation_node)
add_subdirectory(data_structure_tests)
add_subdirectory(task_tests)
cc_test(test_egr_ds_eager_tensor SRCS eager_tensor_test.cc DEPS ${eager_deps} )
cc_test(test_egr_ds_auotgrad_meta SRCS autograd_meta_test.cc DEPS ${eager_deps} grad_node_info)
cc_test(test_egr_ds_grad_node_info SRCS grad_node_info_test.cc DEPS ${eager_deps} grad_node_info)
cc_test(test_egr_ds_tensor_wrapper SRCS tensor_wrapper_test.cc DEPS ${eager_deps} grad_node_info utils)
cc_test(test_egr_ds_auotgrad_meta SRCS autograd_meta_test.cc DEPS ${eager_deps})
cc_test(test_egr_ds_grad_node_info SRCS grad_node_info_test.cc DEPS ${eager_deps})
cc_test(test_egr_ds_tensor_wrapper SRCS tensor_wrapper_test.cc DEPS ${eager_deps})
cc_test(test_egr_ds_grad_tensor_holder SRCS grad_tensor_holder_test.cc DEPS ${eager_deps})
cc_test(test_egr_ds_accumulation_node SRCS accumulation_node_test.cc DEPS ${eager_deps})
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
#include "gtest/gtest.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/kernel_registry.h"
// TODO(jiabin): remove nolint here!!!
using namespace egr; // NOLINT
TEST(AccumulationNode, EagerTensor) {
// Construct Eager Tensor
pten::DenseTensorMeta meta = pten::DenseTensorMeta(
pten::DataType::FLOAT16, paddle::framework::make_ddim({1, 1}));
std::shared_ptr<pten::DenseTensor> dt0 = std::make_shared<pten::DenseTensor>(
std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace()),
meta);
dt0->mutable_data<paddle::platform::float16>()[0] = 10.0;
EagerTensor et0 = EagerTensor(dt0);
std::shared_ptr<pten::DenseTensor> dt1 = std::make_shared<pten::DenseTensor>(
std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace()),
meta);
dt1->mutable_data<paddle::platform::float16>()[0] = 20.0;
EagerTensor et1 = EagerTensor(dt1);
std::shared_ptr<pten::DenseTensor> grad_dt =
std::make_shared<pten::DenseTensor>(
std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace()),
meta);
EagerTensor grad_et = EagerTensor(grad_dt);
// AccumulationNode
GradNodeAccumulation node = GradNodeAccumulation();
// Hook
std::function<egr::EagerTensor(const egr::EagerTensor&)> hook =
[&grad_et](const egr::EagerTensor& t) {
if (t.defined()) {
grad_et.set_impl(t.impl());
return grad_et;
} else {
grad_et.MutableVar()
->GetMutable<paddle::framework::LoDTensor>()
->ShareDataWith(t.Var().Get<paddle::framework::LoDTensor>());
return grad_et;
}
};
node.RetainGrad(hook);
// operator()
EagerTensor ret_et0 = node({{et0}})[0][0];
auto* ret_et0_ptr =
std::dynamic_pointer_cast<pten::DenseTensor>(ret_et0.impl())
->data<paddle::platform::float16>();
CHECK_EQ(ret_et0_ptr[0], paddle::platform::float16(10.0f));
EagerTensor ret_et1 = node({{et1}})[0][0];
auto* ret_et1_ptr =
std::dynamic_pointer_cast<pten::DenseTensor>(ret_et1.impl())
->data<paddle::platform::float16>();
CHECK_EQ(ret_et1_ptr[0], paddle::platform::float16(30.0f));
// Retain Grad
auto* ret_grad_et_ptr =
std::dynamic_pointer_cast<pten::DenseTensor>(grad_et.impl())
->data<paddle::platform::float16>();
CHECK_EQ(ret_grad_et_ptr[0], paddle::platform::float16(30.0f));
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
#include "gtest/gtest.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/kernel_registry.h"
// TODO(jiabin): remove nolint here!!!
using namespace egr; // NOLINT
TEST(GradTensorHolder, Constructor) {
GradSlotMeta slot_meta;
slot_meta.Init(1);
GradTensorHolder grad_tensor_holder = GradTensorHolder({slot_meta});
GradTensorHolder grad_tensor_holder2 = GradTensorHolder(grad_tensor_holder);
// Construct Eager Tensor
pten::DenseTensorMeta meta = pten::DenseTensorMeta(
pten::DataType::FLOAT32, paddle::framework::make_ddim({2, 2}));
std::shared_ptr<pten::DenseTensor> dt = std::make_shared<pten::DenseTensor>(
std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace()),
meta);
EagerTensor et = EagerTensor(dt);
std::vector<std::vector<EagerTensor>> inputs;
inputs.push_back({et});
GradTensorHolder grad_tensor_holder4 = GradTensorHolder(std::move(inputs));
}
TEST(GradTensorHolder, Interfaces) {
// Construct Eager Tensor
pten::DenseTensorMeta meta = pten::DenseTensorMeta(
pten::DataType::FLOAT32, paddle::framework::make_ddim({1, 1}));
std::shared_ptr<pten::DenseTensor> dt0 = std::make_shared<pten::DenseTensor>(
std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace()),
meta);
dt0->mutable_data<float>()[0] = 10.0;
EagerTensor et0 = EagerTensor(dt0);
std::shared_ptr<pten::DenseTensor> dt1 = std::make_shared<pten::DenseTensor>(
std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace()),
meta);
dt1->mutable_data<float>()[0] = 20.0;
EagerTensor et1 = EagerTensor(dt1);
// Constructor empty GradTensorHolder
GradSlotMeta slot_meta;
slot_meta.Init(1);
GradTensorHolder grad_tensor_holder =
GradTensorHolder({slot_meta, slot_meta});
// add():
// fill one
grad_tensor_holder.add(0, 0, et0, true);
// accumulation
grad_tensor_holder.add(1, 0, et0, false);
grad_tensor_holder.add(1, 0, et1, false);
// Buffers()
const auto& buffers = grad_tensor_holder.Buffers();
CHECK_EQ(static_cast<int>(buffers.size()), 2);
CHECK_EQ(static_cast<int>(buffers[0].size()), 1);
CHECK_EQ(static_cast<int>(buffers[1].size()), 1);
// operator[]
const auto& holder_et0 = grad_tensor_holder[0][0];
const auto& holder_et1 = grad_tensor_holder[1][0];
auto* holder_et0_ptr =
std::dynamic_pointer_cast<pten::DenseTensor>(holder_et0.impl())
->data<float>();
auto* holder_et1_ptr =
std::dynamic_pointer_cast<pten::DenseTensor>(holder_et1.impl())
->data<float>();
CHECK_EQ(holder_et0_ptr[0], 1.0f);
CHECK_EQ(holder_et1_ptr[0], 30.0f);
}
......@@ -61,7 +61,7 @@ math_library(gru_compute DEPS activation_functions math_function)
math_library(lstm_compute DEPS activation_functions)
cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
math_library(math_function DEPS blas)
math_library(math_function DEPS blas dense_tensor tensor)
math_library(maxouting)
math_library(pooling)
......
......@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/kernels/functions/eigen/common.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
......@@ -268,6 +269,13 @@ struct ElementwiseAddTo<platform::CPUDeviceContext, T> {
auto& place = *(ctx->eigen_device());
out.device(place) = out + in;
}
void operator()(platform::CPUDeviceContext* ctx, const pten::DenseTensor& src,
pten::DenseTensor* dst) {
auto in = pten::EigenVector<T>::Flatten(src);
auto out = pten::EigenVector<T>::Flatten(*dst);
auto& place = *(ctx->eigen_device());
out.device(place) = out + in;
}
};
template struct ElementwiseAddTo<platform::CPUDeviceContext, platform::float16>;
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/kernels/functions/eigen/common.h"
namespace paddle {
namespace operators {
......@@ -283,10 +284,18 @@ struct ElementwiseAddTo<platform::CUDADeviceContext, T> {
auto& place = *(ctx->eigen_device());
out.device(place) = out + in;
}
void operator()(platform::CUDADeviceContext* ctx,
const pten::DenseTensor& src, pten::DenseTensor* dst) {
auto in = pten::EigenVector<T>::Flatten(src);
auto out = pten::EigenVector<T>::Flatten(*dst);
auto& place = *(ctx->eigen_device());
out.device(place) = out + in;
}
};
template struct ElementwiseAddTo<platform::CUDADeviceContext,
platform::float16>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/core/dense_tensor.h"
namespace paddle {
namespace operators {
......@@ -65,6 +66,8 @@ struct ElementwiseAddTo {
// dst = dst + src
void operator()(DeviceContext* ctx, const framework::Tensor& src,
framework::Tensor* dst);
void operator()(DeviceContext* ctx, const pten::DenseTensor& src,
pten::DenseTensor* dst);
};
template <typename DeviceContext, typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册