Unverified commit 0bb3e5f1, authored by Weilong Wu, committed by GitHub

[Eager] Refactor TensorAdd by template (#39282)

* Refactor TensorAdd func by template and remove gradient_accumulation in eager

* Remove needless target name

* Use overload instead of template
Parent: fc5fa0de
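The core of the change sits in the gradient_accumulator hunks below: `TensorAdd` becomes a function template over the wrapper type, and overloads of `GetInnerDstTensor` / `GetInnerSrcTensor` reduce either a `framework::Variable` or an `egr::EagerTensor` to the shared `pten::DenseTensor` that the arithmetic runs on, which is what lets the separate eager-mode gradient_accumulation target be deleted. A minimal, self-contained sketch of that overload-based dispatch, using hypothetical stand-in types (`DenseLike`, `VariableLike`, `EagerTensorLike`) rather than the real Paddle classes:

// Sketch only: DenseLike, VariableLike and EagerTensorLike are hypothetical
// stand-ins for pten::DenseTensor, framework::Variable and egr::EagerTensor.
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// The common dense representation both wrapper types can be reduced to.
struct DenseLike {
  std::vector<float> data;
};

// Two different wrapper types, each holding the dense data differently.
struct VariableLike {
  std::shared_ptr<DenseLike> inner;
};
struct EagerTensorLike {
  std::shared_ptr<DenseLike> impl;
};

// Overloads in the role of GetInnerSrcTensor / GetInnerDstTensor:
// one pair per wrapper type, all returning the shared representation.
std::shared_ptr<DenseLike> GetInnerSrcTensor(const VariableLike& src) {
  return src.inner;
}
std::shared_ptr<DenseLike> GetInnerDstTensor(VariableLike* dst) {
  return dst->inner;
}
std::shared_ptr<DenseLike> GetInnerSrcTensor(const EagerTensorLike& src) {
  return src.impl;
}
std::shared_ptr<DenseLike> GetInnerDstTensor(EagerTensorLike* dst) {
  return dst->impl;
}

// One templated accumulation routine; overload resolution picks the right
// unwrapping helper, so the math is written exactly once.
template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst) {
  std::shared_ptr<DenseLike> s = GetInnerSrcTensor(src);
  std::shared_ptr<DenseLike> d = GetInnerDstTensor(dst);
  // Element-wise add; assumes equal lengths, mirroring the numel check
  // that the real TensorAdd enforces.
  for (std::size_t i = 0; i < s->data.size(); ++i) {
    d->data[i] += s->data[i];
  }
}

int main() {
  VariableLike var_src{std::make_shared<DenseLike>(DenseLike{{1.f, 2.f}})};
  VariableLike var_dst{std::make_shared<DenseLike>(DenseLike{{10.f, 20.f}})};
  TensorAdd(var_src, &var_dst);  // instantiates TensorAdd<VariableLike>

  EagerTensorLike eg_src{std::make_shared<DenseLike>(DenseLike{{3.f, 4.f}})};
  EagerTensorLike eg_dst{std::make_shared<DenseLike>(DenseLike{{30.f, 40.f}})};
  TensorAdd(eg_src, &eg_dst);  // instantiates TensorAdd<EagerTensorLike>

  std::cout << var_dst.inner->data[0] << " "  // prints 11
            << eg_dst.impl->data[1] << "\n";  // prints 44
  return 0;
}

The last commit bullet ("Use overload instead of template") appears to refer to exactly these helpers: plain per-type overloads keep the unwrapping out of the shared template body, and only explicit instantiations for `framework::Variable` and `egr::EagerTensor` are exported (see the `template void TensorAdd<...>` lines in the diff).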
-set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor legacy autograd_meta grad_node_info grad_tensor_holder gradient_accumulation accumulation_node)
+set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor legacy autograd_meta grad_node_info grad_tensor_holder accumulation_node)
set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy)
set(generated_deps dygraph_function dygraph_node)
@@ -12,7 +12,7 @@ add_subdirectory(accumulation)
add_subdirectory(legacy)
cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api)
-cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulation)
+cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator)
cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api)
cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta hook_utils)
...
-cc_library(gradient_accumulation SRCS gradient_accumulation.cc DEPS blas pten pten_api var_type_traits layer math_function)
-cc_library(accumulation_node SRCS accumulation_node.cc DEPS gradient_accumulation pten pten_api grad_node_info)
+cc_library(accumulation_node SRCS accumulation_node.cc DEPS gradient_accumulator pten pten_api grad_node_info)
@@ -13,8 +13,8 @@
// limitations under the License.
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
-#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
#include "paddle/fluid/eager/eager_tensor.h"
+#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/pten/api/all.h"
#include "paddle/pten/core/dense_tensor.h"
@@ -35,7 +35,7 @@ static void CopyOrAddTensor(egr::EagerTensor* tensor,
*tensor = t;
} else {
// Accumulation
-egr::TensorAdd(t, tensor);
+paddle::imperative::TensorAdd<egr::EagerTensor>(t, tensor);
}
}
...
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
#include <algorithm>
#include <memory>
#include <utility>
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/api/all.h"
#include "paddle/pten/core/convert_utils.h"
#include "unsupported/Eigen/CXX11/Tensor"
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif
namespace egr {
template <typename T>
class TensorAddFunctor : public boost::static_visitor<> {
public:
TensorAddFunctor(int64_t numel, const T* x, T* y)
: numel_(numel), x_(x), y_(y) {}
void operator()(const paddle::platform::CPUPlace& place) const {
paddle::platform::CPUDeviceContext* ctx =
dynamic_cast<paddle::platform::CPUDeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place));
auto blas =
paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext, T>(
*ctx);
blas.AXPY(numel_, 1., x_, y_);
}
// TODO(jiabin): Support xpu here from gradient_accumulator.cc
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void operator()(const paddle::platform::CUDAPlace& place) const {
paddle::platform::CUDADeviceContext* ctx =
dynamic_cast<paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place));
auto blas =
paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext,
T>(*ctx);
blas.AXPY(numel_, 1., x_, y_);
}
#else
void operator()(const paddle::platform::CUDAPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
// TODO(jiabin): Support Npu here from gradient_accumulator.cc
// there is NO blas in CUDAPinnedPlace
void operator()(const paddle::platform::CUDAPinnedPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#ifdef PADDLE_WITH_ASCEND_CL
void operator()(const paddle::platform::NPUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#else
void operator()(const paddle::platform::NPUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
#ifdef PADDLE_WITH_XPU
void operator()(const paddle::platform::XPUPlace& place) const {
paddle::platform::XPUDeviceContext* ctx =
dynamic_cast<paddle::platform::XPUDeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place));
xpu::add<T>(ctx->x_context(), x_, y_, y_, static_cast<int>(numel_));
}
#else
void operator()(const paddle::platform::XPUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
#ifdef PADDLE_WITH_MLU
void operator()(const paddle::platform::MLUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#else
void operator()(const paddle::platform::MLUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
#ifdef PADDLE_WITH_IPU
void operator()(const paddle::platform::IPUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#else
void operator()(const paddle::platform::IPUPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
#endif
void operator()(const paddle::platform::NPUPinnedPlace& place) const {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
private:
int64_t numel_;
const T* x_;
mutable T* y_;
};
template <typename DeviceContext, typename T>
void TensorAddImpl(const std::shared_ptr<pten::DenseTensor>& src,
pten::DenseTensor* dst,
const paddle::platform::Place& place) {
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
paddle::platform::DeviceContext* ctx = pool.Get(place);
auto dev_ctx = dynamic_cast<DeviceContext*>(ctx);
paddle::operators::math::ElementwiseAddTo<DeviceContext, T> func;
func(dev_ctx, *(src.get()), dst);
}
template <typename DeviceContext, typename T>
void TensorAddImpl(const paddle::framework::Tensor& src,
paddle::framework::Tensor* dst,
const paddle::platform::Place& place) {
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
paddle::platform::DeviceContext* ctx = pool.Get(place);
auto dev_ctx = dynamic_cast<DeviceContext*>(ctx);
paddle::operators::math::ElementwiseAddTo<DeviceContext, T> func;
func(dev_ctx, src, dst);
}
void TensorAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) {
// TODO(jiabin): Support other tensor type later
std::shared_ptr<pten::DenseTensor> dst_tensor =
std::dynamic_pointer_cast<pten::DenseTensor>(dst->impl());
std::shared_ptr<pten::DenseTensor> src_tensor =
std::dynamic_pointer_cast<pten::DenseTensor>(src.impl());
auto numel = src_tensor->numel();
if (numel == 0) {
return;
}
PADDLE_ENFORCE_EQ(
dst_tensor->numel(), numel,
paddle::platform::errors::PreconditionNotMet(
"The number of elements of source tensor and destination tensor "
"should be equal, but got the number of elements of source tensor is "
"%zu and the number of elements of destination tensor is %zu.",
numel, dst_tensor->numel()));
auto data_type = pten::TransToProtoVarType(src_tensor->dtype());
auto place = src_tensor->place();
PADDLE_ENFORCE_EQ(pten::TransToProtoVarType(dst_tensor->dtype()), data_type,
paddle::platform::errors::PreconditionNotMet(
"The data type of source tensor and destination tensor "
"should be equal, Otherwise, the calculation results "
"will be incorrect."));
#define PADDLE_TENSOR_ADD(cpp_type) \
if (data_type == paddle::framework::DataTypeTrait<cpp_type>::DataType()) { \
TensorAddFunctor<cpp_type> func( \
numel, src_tensor->data<cpp_type>(), \
dst_tensor->mutable_data<cpp_type>(place)); \
paddle::platform::VisitPlace(place, func); \
return; \
}
// TODO(jiabin): Support NPU here
PADDLE_TENSOR_ADD(float);
// NOTE(phlrain): xpu only support float
#ifndef PADDLE_WITH_XPU
PADDLE_TENSOR_ADD(double);
// NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future
PADDLE_TENSOR_ADD(paddle::platform::complex<float>);
PADDLE_TENSOR_ADD(paddle::platform::complex<double>);
#endif
#undef PADDLE_TENSOR_ADD
if (data_type == paddle::framework::proto::VarType::FP16) {
if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return TensorAddImpl<paddle::platform::CUDADeviceContext,
paddle::platform::float16>(src_tensor,
dst_tensor.get(), place);
#else
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
paddle::framework::DataTypeToString(data_type), place));
#endif
} else if (paddle::platform::is_cpu_place(place)) {
return TensorAddImpl<paddle::platform::CPUDeviceContext,
paddle::platform::float16>(src_tensor,
dst_tensor.get(), place);
}
}
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
paddle::framework::DataTypeToString(data_type), place));
}
void VariableAdd(const egr::EagerTensor& src_tensor,
egr::EagerTensor* dst_tensor) {
auto& src = src_tensor.Var();
auto* dst = dst_tensor->MutableVar();
if (dst->IsType<paddle::framework::LoDTensor>()) {
if (src.IsType<paddle::framework::LoDTensor>()) {
paddle::imperative::TensorAdd(src, dst);
} else if (src.IsType<pten::SelectedRows>()) {
paddle::imperative::SelectedRowsAddToTensor(src, dst);
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Unexpected branch, output variable type is %s",
paddle::framework::ToTypeName(dst->Type())));
}
} else {
if (src.IsType<paddle::framework::LoDTensor>()) {
paddle::framework::Variable new_dst;
paddle::imperative::SelectedRowsAddTensor(*dst, src, &new_dst);
*dst = std::move(new_dst);
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Unexpected branch, output variable type is %s",
paddle::framework::ToTypeName(dst->Type())));
}
}
}
} // namespace egr

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/pten/api/all.h"
namespace egr {
// Accumulation API
void TensorAdd(const egr::EagerTensor& src, egr::EagerTensor* dst);
void VariableAdd(const egr::EagerTensor& src, egr::EagerTensor* dst);
} // namespace egr
@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/eager/grad_tensor_holder.h"
-#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
+#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/math/math_function.h"
@@ -72,17 +72,17 @@ void GradTensorHolder::add(size_t slot_id, size_t rank,
} else {
// Accumulation
if (t.initialized() && buffer_tensor.initialized()) {
-TensorAdd(t, &buffer_tensor);
+paddle::imperative::TensorAdd<egr::EagerTensor>(t, &buffer_tensor);
} else if (t.Var().IsInitialized() &&
buffer_tensor.Var().IsInitialized()) {
-VariableAdd(t, &buffer_tensor);
+paddle::imperative::VariableAdd(t, &buffer_tensor);
} else if (t.Var().IsInitialized() && buffer_tensor.initialized()) {
// TODO(jiabin): This can be merge to upper if case.
buffer_tensor.SyncToVar();
-VariableAdd(t, &buffer_tensor);
+paddle::imperative::VariableAdd(t, &buffer_tensor);
} else if (t.initialized() && buffer_tensor.Var().IsInitialized()) {
buffer_tensor.SyncToTensor();
-TensorAdd(t, &buffer_tensor);
+paddle::imperative::TensorAdd<egr::EagerTensor>(t, &buffer_tensor);
} else {
// Should not happend case
// 1. both not init
...
cc_library(imperative_flag SRCS flags.cc DEPS gflags flags)
IF(WITH_XPU)
-cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
+cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils pten_api)
ELSE()
-cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
+cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils pten_api)
ENDIF()
cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
add_subdirectory(jit)
...
@@ -214,9 +214,37 @@ void TensorAddImpl(const framework::Tensor& src, framework::Tensor* dst,
func(dev_ctx, src, dst);
}
-void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
+std::shared_ptr<pten::DenseTensor> GetInnerDstTensor(egr::EagerTensor* dst) {
+std::shared_ptr<pten::DenseTensor> dst_tensor =
+std::dynamic_pointer_cast<pten::DenseTensor>(dst->impl());
+return dst_tensor;
+}
+std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor(
+const egr::EagerTensor& src) {
+std::shared_ptr<pten::DenseTensor> dst_tensor =
+std::dynamic_pointer_cast<pten::DenseTensor>(src.impl());
+return dst_tensor;
+}
+std::shared_ptr<pten::DenseTensor> GetInnerDstTensor(framework::Variable* dst) {
auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
+return std::make_shared<pten::DenseTensor>(*dst_tensor);
+}
+std::shared_ptr<pten::DenseTensor> GetInnerSrcTensor(
+const framework::Variable& src) {
auto& src_tensor = src.Get<framework::LoDTensor>();
+return std::make_shared<pten::DenseTensor>(src_tensor);
+}
+template <typename VarType>
+void TensorAdd(const VarType& src, VarType* dst) {
+std::shared_ptr<pten::DenseTensor> d_tensor = GetInnerDstTensor(dst);
+std::shared_ptr<pten::DenseTensor> s_tensor = GetInnerSrcTensor(src);
+auto* dst_tensor = d_tensor.get();
+auto& src_tensor = *s_tensor.get();
auto numel = src_tensor.numel();
@@ -336,6 +364,11 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
framework::DataTypeToString(data_type), place));
}
+template void TensorAdd<framework::Variable>(const framework::Variable& src,
+framework::Variable* dst);
+template void TensorAdd<egr::EagerTensor>(const egr::EagerTensor& src,
+egr::EagerTensor* dst);
void SelectedRowsAddToTensor(const framework::Variable& src,
framework::Variable* dst) {
auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
@@ -462,13 +495,41 @@ std::shared_ptr<VariableWrapper> SelectedRowsMerge(
framework::DataTypeToString(data_type)));
}
+void VariableAdd(const egr::EagerTensor& src_tensor,
+egr::EagerTensor* dst_tensor) {
+auto& src = src_tensor.Var();
+auto* dst = dst_tensor->MutableVar();
+if (dst->IsType<paddle::framework::LoDTensor>()) {
+if (src.IsType<paddle::framework::LoDTensor>()) {
+paddle::imperative::TensorAdd<paddle::framework::Variable>(src, dst);
+} else if (src.IsType<pten::SelectedRows>()) {
+paddle::imperative::SelectedRowsAddToTensor(src, dst);
+} else {
+PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+"Unexpected branch, output variable type is %s",
+paddle::framework::ToTypeName(dst->Type())));
+}
+} else {
+if (src.IsType<paddle::framework::LoDTensor>()) {
+paddle::framework::Variable new_dst;
+paddle::imperative::SelectedRowsAddTensor(*dst, src, &new_dst);
+*dst = std::move(new_dst);
+} else {
+PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+"Unexpected branch, output variable type is %s",
+paddle::framework::ToTypeName(dst->Type())));
+}
+}
+}
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
VariableWrapper* dst_var, bool unchange_input) {
auto& src = var->Var();
auto* dst = dst_var->MutableVar();
if (dst->IsType<framework::LoDTensor>()) {
if (src.IsType<framework::LoDTensor>()) {
-TensorAdd(src, dst);
+TensorAdd<framework::Variable>(src, dst);
} else if (src.IsType<pten::SelectedRows>()) {
SelectedRowsAddToTensor(src, dst);
} else {
@@ -535,7 +596,7 @@ void GradientAccumulator::AccumulateGrad() {
"previous gradient.";
if (dst->IsType<framework::LoDTensor>()) {
if (src->IsType<framework::LoDTensor>()) {
-TensorAdd(*src, dst);
+TensorAdd<framework::Variable>(*src, dst);
} else if (src->IsType<pten::SelectedRows>()) {
SelectedRowsAddToTensor(*src, dst);
}
...
@@ -18,6 +18,7 @@
#include <utility>
#include <vector>
+#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/layer.h"
@@ -170,7 +171,10 @@ void SelectedRowsAddTensor(const framework::Variable& src_selected_rows_var,
const framework::Variable& src_tensor_var,
framework::Variable* dst_tensor_var);
-void TensorAdd(const framework::Variable& src, framework::Variable* dst);
+template <typename VarType>
+void TensorAdd(const VarType& src, VarType* dst);
+void VariableAdd(const egr::EagerTensor& src, egr::EagerTensor* dst);
} // namespace imperative
} // namespace paddle
@@ -28,8 +28,6 @@ namespace framework = paddle::framework;
namespace paddle {
namespace imperative {
-void TensorAdd(const framework::Variable& src, framework::Variable* dst);
template <typename Place1, typename Place2, typename T>
int TensorddTest(Place1 place1, Place2 place2, T t1, T t2) {
framework::Variable var1;
@@ -69,7 +67,7 @@ int TensorddTest(Place1 place1, Place2 place2, T t1, T t2) {
sizeof(T) * dst_data.size(), 0);
#endif
}
-imperative::TensorAdd(var1, &var2);
+imperative::TensorAdd<framework::Variable>(var1, &var2);
framework::LoDTensor rlt;
platform::CPUPlace rlt_place;
framework::TensorCopySync(*dst, rlt_place, &rlt);
...