Unverified · Commit 7039bef3 · Authored by Chen Weihang · Committed by GitHub

[AutoParallel] Dygraph basic impl for semi auto parallel (#55698)

* add phi forward api gen impl

* add phi backward gen code

* polish api code gen impl

* polish code gen impl

* remove auto_parallel namespace

* add dygraph forward impl

* add for_auto_parallel cond

* fix code gen errors

* add dygraph backward impl

* resolve conflict with develop

* refactor dist api gen impl

* revert origin api gen impl

* replace template for override func

* fix dnnl macro error

* revert third_party change

* add with distributed macro

* Update grad_tensor_holder.cc details

* merge dist tensor constructor

* change test tensor to replicate

* fix typo

* resolve conflict with develop

* fix out dim error
Parent commit: fcde3991
@@ -26,8 +26,12 @@
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
/**
 * Implementation of GradNodeBase, Edge and GradTensorHolder.
@@ -121,6 +125,14 @@ void GradNodeBase::SetGradInMeta(const paddle::Tensor& fwd_out,
phi::SparseCsrTensor* csr_tensor =
static_cast<phi::SparseCsrTensor*>(fwd_out.impl().get());
dense_tensor = csr_tensor->mutable_non_zero_elements();
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(fwd_out.impl().get())) {
// TODO(chenweihang): DistTensor contains global and local meta, here
// only set the local meta now, we should set global meta later
dense_tensor =
static_cast<phi::distributed::DistTensor*>(fwd_out.impl().get())
->mutable_value();
#endif
} else {
VLOG(7) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
"non-DenseTensor argument.";
@@ -256,11 +268,29 @@ void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in,
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
meta.SetPlace(fwd_in.place());
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(fwd_in.impl().get())) {
phi::DenseTensor* dense_tensor =
static_cast<phi::distributed::DistTensor*>(fwd_in.impl().get())
->mutable_value();
PADDLE_ENFORCE_NE(
dense_tensor->meta().dtype,
phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
meta.SetPlace(fwd_in.place());
#endif
} else {
VLOG(7)
<< "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
"non-DenseTensor argument.";
}
} else {
VLOG(7) << "Unable to initialize the DenseTensorMeta because the Tensor "
"is not initialized.";
}
}
/*
@@ -367,7 +397,8 @@ void GradNodeBase::SetGradOutMeta(const std::vector<paddle::Tensor>& fwd_in,
// Record TensorMeta
if (fwd_in_tensor.impl() && fwd_in_tensor.impl().get()) {
if (phi::DenseTensor::classof(fwd_in_tensor.impl().get())) {
// TODO(chenweihang): DistTensor contains global and local meta, here
// only set the local meta now, we should set global meta later
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(fwd_in_tensor.impl().get());
PADDLE_ENFORCE_NE(dense_tensor->dtype(),
......
@@ -20,6 +20,10 @@
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
namespace egr {
@@ -83,6 +87,23 @@ void GradTensorHolder::CopyValueFromTensor(size_t slot_id,
} else if (t.is_sparse_csr_tensor() || t.is_sparse_coo_tensor()) {
buffer_[slot_id][rank] =
paddle::experimental::sparse::full_like(t, 1, t.dtype());
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (t.is_dist_tensor()) {
VLOG(6) << "Create a new dist tensor.";
// TODO(chenweihang): we need a shard_tensor API in C++
// TODO(chenweihang): replace by valid dist_attr later
auto temp =
paddle::experimental::full(t.shape(), 1, t.dtype(), t.place());
auto dense_temp =
std::dynamic_pointer_cast<phi::DenseTensor>(temp.impl());
auto dist_tensor = std::make_shared<phi::distributed::DistTensor>(
dense_temp,
dense_temp->meta(),
std::make_shared<
phi::distributed::auto_parallel::TensorDistAttr>());
temp.set_impl(dist_tensor);
buffer_[slot_id][rank] = temp;
#endif
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Only Support DENSE_TENSOR, SPARSE_COO_TENSOR, SPARSE_CSR_TENSOR "
@@ -178,6 +199,10 @@ void GradTensorHolder::add(size_t slot_id,
&buffer_values);
}
}
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (t.is_dist_tensor()) {
buffer_tensor = add_ad_func(t, buffer_tensor);
#endif
} else {
// TODO(jiabin): Support Other TensorBase later
// TODO(zhanlve): Replace SelectedRowsAddTensor with add_dygraph_function
......
@@ -297,13 +297,15 @@ void InitDistTensorWithTensor(
if (place == src.place()) {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(src.impl());
self->tensor.set_impl(
std::make_shared<DistTensor>(tensor, tensor->meta(), dist_attr));
VLOG(4) << "Same place, do ShareDataWith for DistTensor.";
} else {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(
src.copy_to(place, true).impl());
self->tensor.set_impl(
std::make_shared<DistTensor>(tensor, tensor->meta(), dist_attr));
VLOG(4) << "Different place, do TensorCopy for DistTensor.";
}
if (src.get_autograd_meta()) {
......
@@ -249,6 +249,27 @@ static PyObject* tensor_method_numpy(TensorObject* self,
place,
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size());
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (self->tensor.is_dist_tensor()) {
// TODO(chenweihang): deal with DistTensor as local DenseTensor now,
// if the local DenseTensor is shard or partial, do gather or reduce?
VLOG(6) << "Getting DistTensor's numpy value";
auto* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
auto& dense_tensor = dist_tensor->value();
cpu_tensor.set_meta(dense_tensor.meta());
// deep copy
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor.Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
// deep copy
paddle::memory::Copy(place,
cpu_tensor.Holder()->ptr(),
place,
dense_tensor.Holder()->ptr(),
dense_tensor.Holder()->size());
#endif
} else {
VLOG(6) << "Getting DenseTensor's numpy value";
auto dense_tensor =
@@ -290,6 +311,22 @@ static PyObject* tensor_method_numpy(TensorObject* self,
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size(),
kind);
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (self->tensor.is_dist_tensor()) {
VLOG(6) << "Getting DistTensor's numpy value";
auto* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
auto& dense_tensor = dist_tensor->value();
cpu_tensor.set_meta(dense_tensor.meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor.Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
dense_tensor.Holder()->ptr(),
dense_tensor.Holder()->size(),
kind);
#endif
} else {
VLOG(6) << "Getting DenseTensor's numpy value";
auto dense_tensor =
......
@@ -9,6 +9,9 @@ set(api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/include/api.h)
set(api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/api.cc)
set(api_header_file_tmp ${api_header_file}.tmp)
set(api_source_file_tmp ${api_source_file}.tmp)
# dist forward api file
set(dist_api_gen_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/dist_api_gen.py)
# backward api file
set(bw_api_gen_file
@@ -21,6 +24,9 @@ set(bw_api_header_file
set(bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/backward_api.cc)
set(bw_api_header_file_tmp ${bw_api_header_file}.tmp)
set(bw_api_source_file_tmp ${bw_api_source_file}.tmp)
# dist backward api file
set(dist_bw_api_gen_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/dist_bw_api_gen.py)
# dygraph(intermediate) api file
set(im_api_gen_file
@@ -124,19 +130,37 @@ endif()
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install pyyaml)
if(WITH_DISTRIBUTE)
# generate dist forward api
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${dist_api_gen_file} --api_yaml_path
${api_yaml_file} ${legacy_api_yaml_file} --api_header_path
${api_header_file_tmp} --api_source_path ${api_source_file_tmp})
# generate dist backward api
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${dist_bw_api_gen_file} --backward_yaml_path
${bw_api_yaml_file} ${legacy_bw_api_yaml_file} --backward_header_path
${bw_api_header_file_tmp} --backward_source_path
${bw_api_source_file_tmp})
else()
# generate forward api
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${api_gen_file} --api_yaml_path ${api_yaml_file}
${legacy_api_yaml_file} --api_header_path ${api_header_file_tmp}
--api_source_path ${api_source_file_tmp})
# generate backward api
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${bw_api_gen_file} --backward_yaml_path
${bw_api_yaml_file} ${legacy_bw_api_yaml_file} --backward_header_path
${bw_api_header_file_tmp} --backward_source_path
${bw_api_source_file_tmp})
endif()
# generate fused_op api
execute_process(
......
@@ -19,6 +19,13 @@ limitations under the License. */
DECLARE_bool(use_stride_kernel);
#include "glog/logging.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
namespace paddle {
namespace experimental {
@@ -475,5 +482,26 @@ void TransStride(phi::DeviceContext* dev_ctx,
phi::SelectedRows* from,
phi::SelectedRows* to) {}
#ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */
phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out) {
if (out) {
// TODO(chenweihang): now all dist case are nullptr
if (out->impl() == nullptr) {
auto dense_t = std::make_shared<phi::DenseTensor>();
// TODO(chenweihang): polish code, dist_attr is null now
auto dist_attr =
std::make_shared<phi::distributed::auto_parallel::TensorDistAttr>();
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
dense_t, phi::DenseTensorMeta(), dist_attr);
out->set_impl(dist_t);
}
return static_cast<phi::distributed::DistTensor*>(out->impl().get());
}
return nullptr;
}
#endif
} // namespace experimental
} // namespace paddle
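For orientation, a minimal, hedged sketch of the contract SetKernelDistOutput (added above) is meant to provide to the generated dist forward code. Only SetKernelDistOutput and DistTensor::mutable_value() are names from this patch; the local variable names below are illustrative.

// Sketch only (not part of the patch): creating a dist output for a kernel.
paddle::Tensor out;                                    // impl() is still nullptr
auto* dist_out = paddle::experimental::SetKernelDistOutput(&out);
// out is now backed by a DistTensor holding an empty local DenseTensor and a
// default-constructed TensorDistAttr (see the TODOs above).
phi::DenseTensor* dense_out = dist_out->mutable_value();
// The selected dense kernel writes its result into dense_out.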
@@ -24,6 +24,12 @@ limitations under the License. */
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/string_tensor.h"
namespace phi {
namespace distributed {
class DistTensor;
} // namespace distributed
} // namespace phi
namespace paddle {
namespace experimental {
@@ -127,5 +133,11 @@ void TransStride(phi::DeviceContext* dev_ctx,
phi::SelectedRows* from,
phi::SelectedRows* to);
#ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */
phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out);
#endif
} // namespace experimental
} // namespace paddle
@@ -27,6 +27,10 @@ limitations under the License. */
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
#include "paddle/phi/kernels/transfer_layout_kernel.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
DECLARE_bool(use_stride_kernel);
namespace paddle {
@@ -567,5 +571,46 @@ void TransDataBackend(const phi::SelectedRows* tensor,
}
}
#ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const Tensor& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel) {
const auto& tensor_in = input.impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
phi::DenseTensor& dense_tensor = *(dist_tensor->mutable_value());
if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
(!NeedTransformPlace(
dense_tensor.place(), target_args_def.backend, transform_flag) &&
!NeedTransformDataType(
dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
!NeedTransformLayout(dense_tensor.layout(),
target_args_def.layout,
dense_tensor.place(),
transform_flag) &&
!NeedTransform2Contiguous(is_stride_kernel,
dense_tensor.meta().is_contiguous()))) {
return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
}
phi::DenseTensor out = TransformData(
&dense_tensor, target_args_def, transform_flag, is_stride_kernel);
// TODO(chenweihang): The global meta in DistTensor is not changed,
// but the local meta in DenseTensor maybe changed, such as layout
// change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
return std::make_shared<phi::distributed::DistTensor>(
std::make_shared<phi::DenseTensor>(std::move(out)),
dist_tensor->meta(),
dist_tensor->dist_attr());
}
return nullptr;
}
#endif
} // namespace experimental
} // namespace paddle
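A short, hedged sketch of how the generated dist branch is expected to consume the helper above before invoking the dense kernel; target_args_def and transform_flag stand in for values the generated code derives from the selected kernel, and the variable names are illustrative.

// Sketch only: input preparation on the dist path.
std::shared_ptr<phi::distributed::DistTensor> dist_input_x =
    PrepareDataForDistTensor(x, target_args_def, transform_flag,
                             /*is_stride_kernel=*/false);
// The dense kernel then receives the (possibly transformed) local value.
const phi::DenseTensor& local_x = dist_input_x->value();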
@@ -20,6 +20,12 @@ limitations under the License. */
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi {
namespace distributed {
class DistTensor;
} // namespace distributed
} // namespace phi
namespace paddle {
namespace experimental {
@@ -165,5 +171,16 @@ inline bool NeedTransformPlace(const phi::Place& src_place,
return ret;
}
#ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */
// TODO(chenweihang): impl Reshard input and output function
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const Tensor& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel);
#endif
} // namespace experimental
} // namespace paddle
@@ -173,6 +173,36 @@ struct KernelTypeParser : ArgsIterator<KernelTypeParser> {
}
};
#ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */
struct DistTensorTypeParser : ArgsIterator<DistTensorTypeParser> {
bool result = true;
void operator()(const Tensor& x) { result &= x.is_dist_tensor(); }
void operator()(const paddle::optional<Tensor>& x) {
if (x) {
result &= x.get_ptr()->is_dist_tensor();
}
}
void operator()(const std::vector<Tensor>& x) {
if (!x.empty()) {
for (auto& t : x) {
result &= t.is_dist_tensor();
}
}
}
// skip other type args, these args don't used in kernel selection
template <typename T>
void operator()(const T& x) {
// do nothing
}
};
#endif
} // namespace detail
template <typename... Args>
@@ -205,5 +235,12 @@ DataLayout ParseLayout(DataLayout layout);
DataLayout ParseLayout(const Tensor& tensor);
DataLayout ParseLayoutWithInputOrder(DataLayout layout, const Tensor& tensor);
#ifdef PADDLE_WITH_DISTRIBUTE
template <typename... Args>
bool AllInputsAreDistTensor(const Args&... args) {
return detail::DistTensorTypeParser().apply(args...).result;
}
#endif
} // namespace experimental
} // namespace paddle
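Taken together with the helpers above, the generated forward API is expected to branch roughly as follows. This is a hedged illustration only: the real code is emitted by dist_api_gen.py (its diff is collapsed below), and the relu signature is just a convenient example.

// Illustrative skeleton of a generated API with a dist branch.
PADDLE_API Tensor relu(const Tensor& x) {
  if (AllInputsAreDistTensor(x)) {
    Tensor api_output;
    auto* dist_out = SetKernelDistOutput(&api_output);
    auto* dense_out = dist_out->mutable_value();
    // ... select the kernel, run InferMeta and the dense kernel on the local
    // DenseTensor obtained via PrepareDataForDistTensor ...
    return api_output;
  }
  // original DenseTensor / SelectedRows path
  // ...
}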
This diff is collapsed.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import yaml
from backward_api_gen import BackwardAPI
from dist_api_gen import DistForwardAPI
######################
# Code Gen Templates #
######################
# 1. Create API Outputs
SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput({});
auto dense_out = dist_out->mutable_value();
"""
INPLACE_OUT_CREATION_TEMPLATE = """
*{} = {};
"""
MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out_{} = SetKernelDistOutput({});
auto dense_out_{} = dist_out_{}->mutable_value();
"""
class DistBackwardAPI(DistForwardAPI, BackwardAPI):
def __init__(self, backward_item_yaml):
BackwardAPI.__init__(self, backward_item_yaml)
self.init_dist_api_members()
# override DistForwardAPI's method
def generate_output_creation_code(self) -> str:
# backward api only need to generate kernel outputs
output_num = len(self.outputs['types'])
output_creation_code = ""
if output_num == 1:
self.dist_output_args.append('dist_out')
self.dense_output_args.append('dense_out')
if self.outputs['types'][0] == 'Tensor':
output_creation_code += SINGLE_OUT_CREATION_TEMPLATE.format(
self.outputs['names'][0]
)
else:
self.vector_output_size_assertion_check()
elif output_num > 1:
for i, out_type in enumerate(self.outputs['types']):
self.dist_output_args.append(f'dist_out_{i}')
self.dense_output_args.append(f'dense_out_{i}')
if out_type == 'Tensor':
output_creation_code += (
MULTI_SINGLE_OUT_CREATION_TEMPLATE.format(
i, self.outputs['names'][i], i, i
)
)
else:
self.vector_output_size_assertion_check()
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
self.api
)
)
return output_creation_code
# override DistForwardAPI's method
def generate_return_code(self) -> str:
return "return;"
# override BaseAPI's method
def get_api_func_name(self):
return self.api
# override BaseAPI's method
# The method lookup order are: (DistBackwardAPI.__mro__)
# <class '__main__.DistBackwardAPI'>,
# <class 'dist_api_gen.DistForwardAPI'>,
# <class 'api_gen.ForwardAPI'>,
# <class 'backward_api_gen.BackwardAPI'>,
# <class 'api_base.BaseAPI'>,
# <class 'object'>
# if don't override it, the ForwardAPI's gene_output wiil be called
def gene_output(
self,
out_dtype_list,
out_tensor_type_list=None,
code_indent='',
inplace_flag=False,
):
return BackwardAPI.gene_output(
self,
out_dtype_list,
out_tensor_type_list,
code_indent,
inplace_flag,
)
# override BaseAPI's method
def get_return_type(self, inplace_flag=False):
return BackwardAPI.get_return_type(self)
# override BaseAPI's method
def gene_return_code(self):
return ""
# override BaseAPI's method
def gene_api_declaration(self) -> str:
return BackwardAPI.gene_api_declaration(self)
def header_include():
return """
#include <tuple>
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/utils/optional.h"
"""
def source_include(header_file_path, fw_header_file_path):
return f"""
#include "{header_file_path}"
#include <memory>
#include "glog/logging.h"
#include "gflags/gflags.h"
#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "{fw_header_file_path}"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h"
DECLARE_bool(conv2d_disable_cudnn);
DECLARE_int32(low_precision_op_list);
"""
def backward_api_namespace():
return (
"""
namespace paddle {
namespace experimental {
""",
"""
} // namespace experimental
} // namespace paddle
""",
)
def generate_backward_api(
backward_yaml_path,
is_fused_backward_yaml,
header_file_path,
source_file_path,
):
bw_apis = []
for each_api_yaml in backward_yaml_path:
with open(each_api_yaml, 'r') as f:
api_list = yaml.load(f, Loader=yaml.FullLoader)
if api_list:
bw_apis.extend(api_list)
header_file = open(header_file_path, 'w')
source_file = open(source_file_path, 'w')
namespace = backward_api_namespace()
header_file.write("#pragma once\n")
header_file.write(header_include())
header_file.write(namespace[0])
include_header_file = (
"paddle/phi/api/backward/fused_backward_api.h"
if is_fused_backward_yaml
else "paddle/phi/api/backward/backward_api.h"
)
include_fw_header_file = (
"paddle/phi/api/include/fused_api.h"
if is_fused_backward_yaml
else "paddle/phi/api/include/api.h"
)
source_file.write(
source_include(include_header_file, include_fw_header_file)
)
source_file.write(namespace[0])
# not all fused ops supoort dygraph
if is_fused_backward_yaml is True:
new_bw_apis = [
bw_api
for bw_api in bw_apis
if "support_dygraph_mode" in bw_api
and bw_api["support_dygraph_mode"] is True
]
bw_apis = new_bw_apis
for bw_api in bw_apis:
dist_bw_api = DistBackwardAPI(bw_api)
header_file.write(dist_bw_api.gene_api_declaration())
if is_fused_backward_yaml is True:
source_file.write(dist_bw_api.gene_api_code())
else:
source_file.write(dist_bw_api.gene_api_code())
header_file.write(namespace[1])
source_file.write(namespace[1])
header_file.close()
source_file.close()
def main():
parser = argparse.ArgumentParser(
description='Generate PaddlePaddle C++ backward API files'
)
parser.add_argument(
'--backward_yaml_path',
help='path to backward yaml file',
nargs='+',
default=['paddle/phi/api/yaml/backward.yaml'],
)
parser.add_argument(
'--is_fused_backward_yaml',
help='flag of fused backward yaml',
action='store_true',
)
parser.add_argument(
'--backward_header_path',
help='output of generated backward header code file',
default='paddle/phi/api/backward/backward_api.h',
)
parser.add_argument(
'--backward_source_path',
help='output of generated backward source code file',
default='paddle/phi/api/lib/backward_api.cc',
)
options = parser.parse_args()
backward_yaml_path = options.backward_yaml_path
is_fused_backward_yaml = options.is_fused_backward_yaml
header_file_path = options.backward_header_path
source_file_path = options.backward_source_path
generate_backward_api(
backward_yaml_path,
is_fused_backward_yaml,
header_file_path,
source_file_path,
)
if __name__ == '__main__':
main()
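To make the templates above concrete, here is a rough, hedged sketch of the kind of C++ that DistBackwardAPI is expected to emit for a single-output backward op. Only the two lines expanded from SINGLE_OUT_CREATION_TEMPLATE and the final bare return are grounded in this file; the op name, signature details and the elided kernel-dispatch code are illustrative.

// Illustrative output of dist_bw_api_gen.py for a single-output backward op.
PADDLE_API void relu_grad(const Tensor& out, const Tensor& out_grad,
                          Tensor* x_grad) {
  if (AllInputsAreDistTensor(out, out_grad)) {
    // Expanded from SINGLE_OUT_CREATION_TEMPLATE:
    auto dist_out = SetKernelDistOutput(x_grad);
    auto dense_out = dist_out->mutable_value();
    // ... prepare dist inputs, run InferMeta and the dense grad kernel ...
    return;  // DistBackwardAPI.generate_return_code() emits a bare return
  }
  // original dense backward path
  // ...
}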
@@ -12,6 +12,7 @@ collect_srcs(
flags.cc
errors.cc
enforce.cc
storage_properties.cc
os_info.cc
kernel_context.cc
ddim.cc
......
@@ -18,6 +18,9 @@
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
class DenseTensorUtils;
namespace distributed {
namespace auto_parallel {
@@ -54,10 +57,10 @@ class DistTensor final
}
DistTensor(const std::shared_ptr<phi::DenseTensor>& dense_tensor,
const DenseTensorMeta& meta,
const std::shared_ptr<TensorDistAttr>& dist_attr)
: meta_(meta), dist_attr_(dist_attr) {
value_ = std::make_unique<DenseTensor>(*dense_tensor);
}
~DistTensor() = default;
@@ -121,6 +124,8 @@ class DistTensor final
void set_meta(const DenseTensorMeta& meta);
private:
friend class phi::DenseTensorUtils;
DenseTensorMeta meta_;
std::shared_ptr<TensorDistAttr> dist_attr_{nullptr};
std::unique_ptr<DenseTensor> value_{nullptr};
......
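The constructor now takes the local DenseTensorMeta explicitly, which is why every call site below (reshard functions, Python bindings, tests) gains a meta argument. A minimal hedged sketch of the new usage; alloc, dtype and dims mirror the locals in the updated unit test and are assumptions here.

// Sketch: constructing a DistTensor with the new (dense, meta, dist_attr) form.
auto dense = std::make_shared<phi::DenseTensor>(
    alloc, phi::DenseTensorMeta(dtype, dims));
auto dist_attr =
    std::make_shared<phi::distributed::auto_parallel::TensorDistAttr>();
phi::distributed::DistTensor dist(dense, dense->meta(), dist_attr);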
@@ -93,6 +93,7 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_physical_tensor_cur_rank),
out_physical_tensor_cur_rank.meta(),
out_dist_attr);
}
......
@@ -66,7 +66,9 @@ std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
dev_ctx, in_physical_tensor_cur_rank, in_process_ids);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_all_gather),
out_all_gather.meta(),
out_dist_attr);
}
} // namespace distributed
......
@@ -17,11 +17,15 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/string_tensor.h"
#include "paddle/phi/core/string_tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
namespace phi {
@@ -84,6 +88,12 @@ void MetaTensor::set_dims(const DDim& dims) {
} else if (phi::SparseCsrTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
->dims = dims;
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<distributed::DistTensor*>(tensor_))
->dims = dims;
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported setting dims for `%s`.", tensor_->type_info().name()));
@@ -115,7 +125,12 @@ void MetaTensor::set_dtype(DataType dtype) {
} else if (phi::SparseCsrTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
->dtype = dtype;
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<distributed::DistTensor*>(tensor_))
->dtype = dtype;
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported settting dtype for `%s`.", tensor_->type_info().name()));
@@ -146,6 +161,12 @@ void MetaTensor::set_layout(DataLayout layout) {
} else if (phi::SparseCsrTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
->layout = layout;
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<distributed::DistTensor*>(tensor_))
->layout = layout;
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported settting layout for `%s`.", tensor_->type_info().name()));
@@ -156,7 +177,11 @@ void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
ValidCheck(*this);
ValidCheck(meta_tensor);
if (phi::SparseCooTensor::classof(tensor_) ||
phi::SparseCsrTensor::classof(tensor_)
#ifdef PADDLE_WITH_DISTRIBUTE
|| phi::distributed::DistTensor::classof(tensor_)
#endif
) {
return;
}
if (meta_tensor.lod().empty()) {
@@ -182,7 +207,11 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
if (phi::DenseTensor::classof(tensor_) ||
phi::SelectedRows::classof(tensor_) ||
phi::SparseCooTensor::classof(tensor_) ||
phi::SparseCsrTensor::classof(tensor_)
#ifdef PADDLE_WITH_DISTRIBUTE
|| phi::distributed::DistTensor::classof(tensor_)
#endif
) {
share_dims(meta_tensor);
set_dtype(meta_tensor.dtype());
set_layout(meta_tensor.layout());
@@ -207,7 +236,12 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
bool is_selected_rows = phi::SelectedRows::classof(tensor_);
bool is_sparse_coo = phi::SparseCooTensor::classof(tensor_);
bool is_sparse_csr = phi::SparseCsrTensor::classof(tensor_);
bool is_dist_tensor = false;
#ifdef PADDLE_WITH_DISTRIBUTE
is_dist_tensor = phi::distributed::DistTensor::classof(tensor_);
#endif
if (is_dense_tensor || is_selected_rows || is_sparse_coo || is_sparse_csr ||
is_dist_tensor) {
if (is_selected_rows) {
const auto in_tensor_base = meta_tensor.tensor();
PADDLE_ENFORCE_EQ(
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/storage_properties.h"
namespace phi {
std::unique_ptr<StorageProperties> CopyStorageProperties(
const std::unique_ptr<StorageProperties>& sp) {
if (sp) {
if (NPUStorageProperties::classof(sp.get())) {
auto result = std::make_unique<NPUStorageProperties>();
result->storage_format =
static_cast<NPUStorageProperties*>(sp.get())->storage_format;
result->storage_dims =
static_cast<NPUStorageProperties*>(sp.get())->storage_dims;
return result;
#ifdef PADDLE_WITH_DNNL
} else if (OneDNNStorageProperties::classof(sp.get())) {
auto result = std::make_unique<OneDNNStorageProperties>();
result->format = static_cast<OneDNNStorageProperties*>(sp.get())->format;
result->mem_desc =
static_cast<OneDNNStorageProperties*>(sp.get())->mem_desc;
return result;
#endif
} else {
return nullptr;
}
}
return nullptr;
}
} // namespace phi
@@ -28,11 +28,13 @@ namespace phi {
struct StorageProperties {
public:
virtual ~StorageProperties() = default;
TypeInfo<StorageProperties> type_info() const { return type_info_; }
private:
template <typename T, typename U>
friend class TypeInfoTraits;
TypeInfo<StorageProperties> type_info_{
TypeInfo<StorageProperties>::kUnknownType};
};
@@ -70,29 +72,7 @@ struct OneDNNStorageProperties
};
#endif
std::unique_ptr<StorageProperties> CopyStorageProperties(
const std::unique_ptr<StorageProperties>& sp);
} // namespace phi
@@ -21,8 +21,16 @@ limitations under the License. */
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/tensor_array.h"
#include "paddle/phi/core/tensor_meta.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
namespace phi {
// TODO(chenweihang): DenseTensorUtils has been abused during the development
// process, and now its semantics are incorrect. It can not only operate
// DenseTensors, but also other types of Tensors, requiring renaming or
// splitting
class DenseTensorUtils {
public:
static DenseTensorMeta* GetMutableMeta(DenseTensor* tensor) {
@@ -37,6 +45,12 @@ class DenseTensorUtils {
return &(tensor->meta_);
}
#ifdef PADDLE_WITH_DISTRIBUTE
static DenseTensorMeta* GetMutableMeta(distributed::DistTensor* tensor) {
return &(tensor->meta_);
}
#endif
static const std::shared_ptr<phi::Allocation>& GetHolder(
const DenseTensor& tensor) {
return tensor.holder_;
......
@@ -18,6 +18,7 @@ import numpy as np
import paddle
import paddle.distributed as dist
import paddle.nn.functional as F
class TestDistTensor(unittest.TestCase):
@@ -52,5 +53,36 @@ class TestDistTensor(unittest.TestCase):
self.assertEqual(dist_tensor_with_tensor.dist_attr, dist_attr)
class TestDistTensorForDygraphAPI(unittest.TestCase):
def check_tensor_eq(self, a, b):
np1 = a.numpy()
np2 = b.numpy()
np.testing.assert_allclose(np1, np2, rtol=1e-05)
def create_local_and_dist_tensor_pair(self, np_array):
local_t = paddle.to_tensor(np_array, dtype='float32')
mesh = dist.ProcessMesh([0], dim_names=["x"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
dist_t = dist.shard_tensor(np_array, dist_attr=dist_attr)
local_t.stop_gradient = False
dist_t.stop_gradient = False
return local_t, dist_t
def test_relu_api_for_dist_tensor(self):
x = np.random.random(size=[4, 4]).astype("float32")
local_in, dist_in = self.create_local_and_dist_tensor_pair(x)
local_out = F.relu(local_in)
dist_out = F.relu(dist_in)
self.check_tensor_eq(local_out, dist_out)
# test backward
local_out.backward()
dist_out.backward()
self.check_tensor_eq(local_in.grad, dist_in.grad)
if __name__ == "__main__":
unittest.main()
@@ -47,7 +47,7 @@ TEST(dist_tensor, constructor) {
EXPECT_TRUE(x3.initialized());
auto a = std::make_shared<DenseTensor>(alloc, DenseTensorMeta(dtype, dims));
DistTensor x4(a, a->meta(), dist_attr);
EXPECT_TRUE(x4.defined());
EXPECT_TRUE(x4.initialized());
}
......
@@ -54,7 +54,9 @@ std::shared_ptr<DistTensor> ConstructReplicatedDistCPU(
dist_attr->set_process_mesh(mesh);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(input_dense),
input_dense.meta(),
dist_attr);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -87,7 +89,9 @@ std::shared_ptr<DistTensor> ConstructReplicatedDistGPU(
dist_attr->set_process_mesh(mesh);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(input_dense_gpu),
input_dense_gpu.meta(),
dist_attr);
}
#endif
......