“419465223b3cc66a8895f00b4bb234d78522a3a7”上不存在“python/paddle/fluid/tests/unittests/test_dist_fleet_spmt.py”
提交 2e2074ef 编写于 作者: Y yaoxuefeng6

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into mod_dataset_v2

...@@ -441,6 +441,7 @@ class SectionWorker : public DeviceWorker { ...@@ -441,6 +441,7 @@ class SectionWorker : public DeviceWorker {
skip_vars_ = skip_vars; skip_vars_ = skip_vars;
} }
static void ResetBatchId() { batch_id_ = 0; } static void ResetBatchId() { batch_id_ = 0; }
static void ResetThreadCompletedFlag() { threads_completed = false; }
static std::atomic<int> cpu_id_; static std::atomic<int> cpu_id_;
......
...@@ -41,6 +41,11 @@ message LocalSGDConfig { ...@@ -41,6 +41,11 @@ message LocalSGDConfig {
optional int32 begin_step = 2 [ default = 1 ]; optional int32 begin_step = 2 [ default = 1 ];
} }
message AdaptiveLocalSGDConfig {
optional int32 init_k_steps = 1 [ default = 1 ];
optional int32 begin_step = 2 [ default = 1 ];
}
message GradientMergeConfig { message GradientMergeConfig {
optional int32 k_steps = 1 [ default = 1 ]; optional int32 k_steps = 1 [ default = 1 ];
optional bool avg = 2 [ default = true ]; optional bool avg = 2 [ default = true ];
...@@ -121,6 +126,7 @@ message DistributedStrategy { ...@@ -121,6 +126,7 @@ message DistributedStrategy {
optional bool cudnn_exhaustive_search = 21 [ default = true ]; optional bool cudnn_exhaustive_search = 21 [ default = true ];
optional int32 conv_workspace_size_limit = 22 [ default = 4000 ]; optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ]; optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
optional bool adaptive_localsgd = 24 [ default = false ];
optional RecomputeConfig recompute_configs = 101; optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102; optional AMPConfig amp_configs = 102;
...@@ -131,6 +137,7 @@ message DistributedStrategy { ...@@ -131,6 +137,7 @@ message DistributedStrategy {
optional AsyncConfig a_sync_configs = 107; optional AsyncConfig a_sync_configs = 107;
optional LarsConfig lars_configs = 108; optional LarsConfig lars_configs = 108;
optional LambConfig lamb_configs = 109; optional LambConfig lamb_configs = 109;
optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
optional BuildStrategy build_strategy = 201; optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202; optional ExecutionStrategy execution_strategy = 202;
} }
......
...@@ -251,6 +251,7 @@ void PipelineTrainer::Finalize() { ...@@ -251,6 +251,7 @@ void PipelineTrainer::Finalize() {
} }
root_scope_->DropKids(); root_scope_->DropKids();
SectionWorker::ResetBatchId(); SectionWorker::ResetBatchId();
SectionWorker::ResetThreadCompletedFlag();
} }
Scope* PipelineTrainer::GetWorkerScope(int thread_id) { Scope* PipelineTrainer::GetWorkerScope(int thread_id) {
......
...@@ -196,7 +196,6 @@ void SectionWorker::TrainFiles() { ...@@ -196,7 +196,6 @@ void SectionWorker::TrainFiles() {
if (threads_completed) { if (threads_completed) {
VLOG(3) << "thread " << thread_id_ << " completed."; VLOG(3) << "thread " << thread_id_ << " completed.";
lk.unlock(); lk.unlock();
threads_completed = false;
return; return;
} }
lk.unlock(); lk.unlock();
...@@ -459,7 +458,6 @@ void SectionWorker::TrainFilesWithProfiler() { ...@@ -459,7 +458,6 @@ void SectionWorker::TrainFilesWithProfiler() {
<< ", mean_time: " << op_total_time[i] / op_count[i]; << ", mean_time: " << op_total_time[i] / op_count[i];
} }
VLOG(0) << "================================"; VLOG(0) << "================================";
threads_completed = false;
return; return;
} }
lk.unlock(); lk.unlock();
......
...@@ -9,7 +9,8 @@ if(WITH_GPU AND TENSORRT_FOUND) ...@@ -9,7 +9,8 @@ if(WITH_GPU AND TENSORRT_FOUND)
endif() endif()
function(download_data install_dir data_file) function(download_data install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file}) string(REGEX MATCH "[^/\\]+$" file_name ${data_file})
if (NOT EXISTS ${install_dir}/${file_name})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${data_file}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${data_file})
endif() endif()
endfunction() endfunction()
......
...@@ -43,9 +43,9 @@ class FakeInitOp : public framework::OperatorBase { ...@@ -43,9 +43,9 @@ class FakeInitOp : public framework::OperatorBase {
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value(); tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape"))); tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else { } else {
PADDLE_THROW( PADDLE_THROW(platform::errors::InvalidArgument(
"fake init op's output only" "fake init op's output only"
"supports SelectedRows and LoDTensor"); "supports SelectedRows and LoDTensor"));
} }
} }
}; };
......
...@@ -134,7 +134,10 @@ void ListenAndServOp::RunSyncLoop( ...@@ -134,7 +134,10 @@ void ListenAndServOp::RunSyncLoop(
auto optimize_blocks = auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks); Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); platform::errors::PreconditionNotMet(
"Invalid number of blocks in server program. Expected "
"equal or greater than 2. Recieved %zu",
num_blocks));
// Prepare all the server block // Prepare all the server block
std::vector<int> optimize_blocks_list; std::vector<int> optimize_blocks_list;
...@@ -218,7 +221,8 @@ void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope, ...@@ -218,7 +221,8 @@ void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
VLOG(3) << "reset sparse var: " << varname; VLOG(3) << "reset sparse var: " << varname;
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear(); var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
} else { } else {
PADDLE_THROW("The type of sparse var should be SelectedRows"); PADDLE_THROW(platform::errors::PreconditionNotMet(
"The type of sparse var should be SelectedRows"));
} }
} }
if (UNLIKELY(reset_all)) { if (UNLIKELY(reset_all)) {
...@@ -235,7 +239,8 @@ void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope, ...@@ -235,7 +239,8 @@ void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(), math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
static_cast<float>(0)); static_cast<float>(0));
} else { } else {
PADDLE_THROW("The type of dense var should be in [LoDTensor, Tensor]"); PADDLE_THROW(platform::errors::PreconditionNotMet(
"The type of dense var should be in [LoDTensor, Tensor]"));
} }
} }
} }
...@@ -254,8 +259,15 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -254,8 +259,15 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(grad_and_id, ':', &pieces); split(grad_and_id, ':', &pieces);
VLOG(3) << "after split, key = " << pieces[0] << ", id=" << pieces[1]; VLOG(3) << "after split, key = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2); PADDLE_ENFORCE_EQ(pieces.size(), 2,
PADDLE_ENFORCE_EQ(out_map->count(pieces[0]), 0); platform::errors::PreconditionNotMet(
"Invalid format of grad_and_id argument. "
"Expected \"grad:block_id\". Recieved %s",
grad_and_id.c_str()));
PADDLE_ENFORCE_EQ(out_map->count(pieces[0]), 0,
platform::errors::AlreadyExists(
"The gradient name %s has already existed in out_map",
pieces[0].c_str()));
int block_id = std::stoi(pieces[1]); int block_id = std::stoi(pieces[1]);
(*out_map)[pieces[0]] = block_id; (*out_map)[pieces[0]] = block_id;
...@@ -267,7 +279,10 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -267,7 +279,10 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); platform::errors::PreconditionNotMet(
"Invalid number of blocks in server program. Expected "
"equal or greater than 2. Recieved %zu",
num_blocks));
std::vector<int> block_list; std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) { for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid); block_list.push_back(blkid);
...@@ -342,9 +357,9 @@ void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames, ...@@ -342,9 +357,9 @@ void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
var->IsType<framework::Tensor>()) { var->IsType<framework::Tensor>()) {
dense_vars_.push_back(varname); dense_vars_.push_back(varname);
} else { } else {
PADDLE_THROW( PADDLE_THROW(platform::errors::PreconditionNotMet(
"The type of received var should be in [SelectedRows, LoDTensor, " "The type of received var should be in [SelectedRows, LoDTensor, "
"Tensor]."); "Tensor]."));
} }
} }
} }
...@@ -450,7 +465,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -450,7 +465,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
split(prefetch_var_name_and_id, ':', &pieces); split(prefetch_var_name_and_id, ':', &pieces);
VLOG(3) << "after split, prefetch_var = " << pieces[0] VLOG(3) << "after split, prefetch_var = " << pieces[0]
<< ", id=" << pieces[1]; << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2); PADDLE_ENFORCE_EQ(
pieces.size(), 2,
platform::errors::PreconditionNotMet(
"Invalid format of prefetch_var_name_and_id argument. "
"Expected \"xxx:xxx\". Recieved %s",
prefetch_var_name_and_id.c_str()));
int block_id = std::stoi(pieces[1]); int block_id = std::stoi(pieces[1]);
prefetch_block_id_list.push_back(block_id); prefetch_block_id_list.push_back(block_id);
...@@ -476,7 +496,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -476,7 +496,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sparse_grad_name_to_param_name_str) { sparse_grad_name_to_param_name_str) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(sparse_grad_name_and_param_name, ':', &pieces); split(sparse_grad_name_and_param_name, ':', &pieces);
PADDLE_ENFORCE_EQ(pieces.size(), 2); PADDLE_ENFORCE_EQ(
pieces.size(), 2,
platform::errors::PreconditionNotMet(
"Invalid format of sparse_grad_name_and_param_name argument. "
"Expected \"xxx:xxx\". Recieved %s",
sparse_grad_name_and_param_name.c_str()));
VLOG(3) << "after split, sparse_grad_name = " << pieces[0] VLOG(3) << "after split, sparse_grad_name = " << pieces[0]
<< ", param_name = " << pieces[1]; << ", param_name = " << pieces[1];
sparse_grad_name_to_param_name[pieces[0]] = pieces[1]; sparse_grad_name_to_param_name[pieces[0]] = pieces[1];
......
...@@ -61,8 +61,15 @@ void elementwise_floor_div(const framework::ExecutionContext &ctx, ...@@ -61,8 +61,15 @@ void elementwise_floor_div(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z) { const framework::Tensor *y, framework::Tensor *z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<FloorDivFunctor<T>, DeviceContext, T>( auto x_dims = x->dims();
ctx, x, y, axis, FloorDivFunctor<T>(), z); auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<FloorDivFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, FloorDivFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseFloorDivFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseFloorDivFunctor<T>(), z);
}
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
...@@ -22,10 +22,12 @@ limitations under the License. */ ...@@ -22,10 +22,12 @@ limitations under the License. */
#include <cblas.h> #include <cblas.h>
#endif #endif
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function_impl.h" #include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -63,6 +65,55 @@ DEFINE_CPU_TRANS(4); ...@@ -63,6 +65,55 @@ DEFINE_CPU_TRANS(4);
DEFINE_CPU_TRANS(5); DEFINE_CPU_TRANS(5);
DEFINE_CPU_TRANS(6); DEFINE_CPU_TRANS(6);
template <typename T>
struct TransposeNormal<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& in, framework::Tensor* out,
const std::vector<int>& axis) {
const int rank = axis.size();
auto in_stride = framework::stride(in.dims());
auto out_stride = framework::stride(out->dims());
const T* in_ptr = in.data<T>();
T* out_ptr = out->data<T>();
auto transpose_helper = [&](int64_t beg, int64_t end) {
for (int64_t out_idx = beg; out_idx < end; ++out_idx) {
int64_t in_idx = 0;
int64_t tmp_idx = out_idx;
// calculate the input index
for (int i = 0; i < rank; ++i) {
const int64_t coordinate = tmp_idx / out_stride[i];
tmp_idx -= coordinate * out_stride[i];
in_idx += coordinate * in_stride[axis[i]];
}
out_ptr[out_idx] = in_ptr[in_idx];
}
};
double cost_per_iteration =
rank * (Eigen::TensorOpCost::DivCost<int64_t>() +
2 * Eigen::TensorOpCost::MulCost<int64_t>() +
2 * Eigen::TensorOpCost::AddCost<int64_t>());
Eigen::TensorOpCost cost(sizeof(T), sizeof(T), cost_per_iteration);
auto* cpu_device = context.eigen_pool_device();
cpu_device->parallelFor(out->numel(), cost, std::move(transpose_helper));
}
};
// define transpose normal
#define DEFINE_CPU_TRANS_NORMAL(TYPE) \
template struct TransposeNormal<platform::CPUDeviceContext, TYPE>
DEFINE_CPU_TRANS_NORMAL(platform::float16);
DEFINE_CPU_TRANS_NORMAL(platform::bfloat16);
DEFINE_CPU_TRANS_NORMAL(float);
DEFINE_CPU_TRANS_NORMAL(double);
DEFINE_CPU_TRANS_NORMAL(int);
DEFINE_CPU_TRANS_NORMAL(int64_t);
DEFINE_CPU_TRANS_NORMAL(bool);
DEFINE_CPU_TRANS_NORMAL(int16_t);
DEFINE_CPU_TRANS_NORMAL(uint8_t);
DEFINE_CPU_TRANS_NORMAL(int8_t);
struct TensorSetConstantCPU { struct TensorSetConstantCPU {
TensorSetConstantCPU(framework::Tensor* tensor, float value) TensorSetConstantCPU(framework::Tensor* tensor, float value)
: tensor_(tensor), value_(value) {} : tensor_(tensor), value_(value) {}
......
...@@ -11,8 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h" #include "paddle/fluid/operators/math/math_function_impl.h"
...@@ -23,6 +26,7 @@ namespace operators { ...@@ -23,6 +26,7 @@ namespace operators {
namespace math { namespace math {
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
template struct SetConstant<platform::CUDADeviceContext, platform::float16>; template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>; template struct SetConstant<platform::CUDADeviceContext, float>;
...@@ -31,12 +35,13 @@ template struct SetConstant<platform::CUDADeviceContext, int>; ...@@ -31,12 +35,13 @@ template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>; template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>; template struct SetConstant<platform::CUDADeviceContext, bool>;
#define DEFINE_GPU_TRANS(RANK) \ #define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \ template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \ template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \ template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>; \ template struct Transpose<platform::CUDADeviceContext, bfloat16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>; \ template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>; template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>;
DEFINE_GPU_TRANS(1); DEFINE_GPU_TRANS(1);
...@@ -46,6 +51,88 @@ DEFINE_GPU_TRANS(4); ...@@ -46,6 +51,88 @@ DEFINE_GPU_TRANS(4);
DEFINE_GPU_TRANS(5); DEFINE_GPU_TRANS(5);
DEFINE_GPU_TRANS(6); DEFINE_GPU_TRANS(6);
#define REINTERPRET(T, DST_PTR, SRC_PTR) \
T* DST_PTR = reinterpret_cast<T*>(SRC_PTR)
template <typename T>
__global__ void TransposeNormalKernel(const T* in_ptr, T* out_ptr,
int64_t element,
const int64_t* in_stride_ptr,
const int64_t* out_stride_ptr,
const int64_t* axis_ptr, int rank) {
CUDA_KERNEL_LOOP(out_idx, element) {
int64_t in_idx = 0;
int64_t tmp_idx = out_idx;
for (int i = 0; i < rank; ++i) {
const int64_t coordinate = tmp_idx / out_stride_ptr[i];
tmp_idx -= coordinate * out_stride_ptr[i];
in_idx += coordinate * in_stride_ptr[axis_ptr[i]];
}
out_ptr[out_idx] = in_ptr[in_idx];
}
}
template <typename T>
struct TransposeNormal<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& in, framework::Tensor* out,
const std::vector<int>& axis) {
const int rank = axis.size();
auto in_stride = framework::stride(in.dims());
auto out_stride = framework::stride(out->dims());
auto* in_ptr = in.data<T>();
auto* out_ptr = out->data<T>();
// copy in_stride, out_stride, axis to gpu device
const platform::CUDAPlace& cuda_place =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace());
platform::CPUPlace cpu_place = platform::CPUPlace();
size_t size = 3 * rank * sizeof(int64_t);
auto cpu_buf_holder = memory::AllocShared(cpu_place, size);
auto cuda_buf_holder = memory::AllocShared(cuda_place, size);
REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr());
REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr());
for (int i = 0; i < rank; ++i) {
cpu_buf[i] = in_stride[i];
cpu_buf[rank + i] = out_stride[i];
cpu_buf[2 * rank + i] = axis[i];
}
memory::Copy(cuda_place, cuda_buf, cpu_place, cpu_buf, size,
context.stream());
REINTERPRET(const int64_t, in_stride_ptr, cuda_buf);
REINTERPRET(const int64_t, out_stride_ptr, cuda_buf + rank);
REINTERPRET(const int64_t, axis_ptr, cuda_buf + 2 * rank);
const int MAX_BLOCK_DIM = context.GetMaxThreadsPerBlock();
const int MAX_GRID_DIM =
context.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
int64_t elements = in.numel();
int block_size = (elements >= MAX_BLOCK_DIM)
? MAX_BLOCK_DIM
: (1 << static_cast<int>(std::log2(elements)));
int grid_size = elements / block_size;
grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size;
TransposeNormalKernel<T><<<grid_size, block_size, 0, context.stream()>>>(
in_ptr, out_ptr, elements, in_stride_ptr, out_stride_ptr, axis_ptr,
rank);
}
};
// define transpose normal
#define DEFINE_GPU_TRANS_NORMAL(TYPE) \
template struct TransposeNormal<platform::CUDADeviceContext, TYPE>
DEFINE_GPU_TRANS_NORMAL(float16);
DEFINE_GPU_TRANS_NORMAL(bfloat16);
DEFINE_GPU_TRANS_NORMAL(float);
DEFINE_GPU_TRANS_NORMAL(double);
DEFINE_GPU_TRANS_NORMAL(int);
DEFINE_GPU_TRANS_NORMAL(int64_t);
DEFINE_GPU_TRANS_NORMAL(bool);
DEFINE_GPU_TRANS_NORMAL(int16_t);
DEFINE_GPU_TRANS_NORMAL(uint8_t);
DEFINE_GPU_TRANS_NORMAL(int8_t);
struct TensorSetConstantGPU { struct TensorSetConstantGPU {
TensorSetConstantGPU(const platform::DeviceContext& context, TensorSetConstantGPU(const platform::DeviceContext& context,
framework::Tensor* tensor, float value) framework::Tensor* tensor, float value)
......
...@@ -26,6 +26,14 @@ limitations under the License. */ ...@@ -26,6 +26,14 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
template <typename DeviceContext, typename T>
struct TransposeNormal {
// for dims >= 7 situation
void operator()(const DeviceContext& context, const framework::Tensor& in,
framework::Tensor* out, const std::vector<int>& axis);
};
template <typename DeviceContext, typename T, int Rank> template <typename DeviceContext, typename T, int Rank>
struct Transpose { struct Transpose {
void operator()(const DeviceContext& context, const framework::Tensor& in, void operator()(const DeviceContext& context, const framework::Tensor& in,
......
...@@ -226,8 +226,8 @@ TEST(math_funciton, set_constant) { ...@@ -226,8 +226,8 @@ TEST(math_funciton, set_constant) {
for (int64_t i = 0; i < t.numel(); ++i) { for (int64_t i = 0; i < t.numel(); ++i) {
PADDLE_ENFORCE_EQ(10, t.data<int>()[i], PADDLE_ENFORCE_EQ(10, t.data<int>()[i],
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"Each value of input" "Each value of input tensor should be 10, "
"tensor should be 10, but received %d.", "but received %d.",
t.data<int>()[i])); t.data<int>()[i]));
} }
delete ctx; delete ctx;
......
...@@ -33,10 +33,10 @@ namespace math { ...@@ -33,10 +33,10 @@ namespace math {
class Sampler { class Sampler {
public: public:
explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) { explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) {
PADDLE_ENFORCE_GT(range, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(
"Range should be" range, 0,
" greater than 0, but recevied %d.", platform::errors::InvalidArgument(
range)); "Range should be greater than 0, but recevied %d.", range));
if (seed == 0) { if (seed == 0) {
std::random_device r; std::random_device r;
seed_ = r(); seed_ = r();
......
...@@ -34,16 +34,15 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> { ...@@ -34,16 +34,15 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* col, const std::vector<int>& paddings, framework::Tensor* col,
const DataLayout data_layout) const { const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
vol.dims().size(), 4, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("The dimension of" "The dimension of vol should be 4, but received %d.",
" vol should be 4, but received %d.", vol.dims().size()));
vol.dims().size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(col->dims().size(), 7,
col->dims().size(), 7, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("The dimension of" "The dimension of col should be 7, but received %d.",
"col should be 7, but received %d.", col->dims().size()));
col->dims().size()));
int input_channels = int input_channels =
(data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]); (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
...@@ -152,16 +151,15 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> { ...@@ -152,16 +151,15 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* vol, const std::vector<int>& paddings, framework::Tensor* vol,
const DataLayout data_layout) const { const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
vol->dims().size(), 4, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("The dimension of vol" "The dimension of vol should be 4, but received %d.",
" should be 4, but received %d.", vol->dims().size()));
vol->dims().size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(col.dims().size(), 7,
col.dims().size(), 7, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("The dimension of col" "The dimension of col should be 7, but received %d.",
" should be 7, but received %d.", col.dims().size()));
col.dims().size()));
int input_channels = int input_channels =
(data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]); (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
...@@ -192,29 +190,29 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> { ...@@ -192,29 +190,29 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
((dilations[0] * (filter_depth - 1) + 1))) / ((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] + strides[0] +
1; 1;
PADDLE_ENFORCE_EQ(input_depth_tmp, output_depth, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( input_depth_tmp, output_depth,
"input_depth(%d)" platform::errors::InvalidArgument(
" and output_depth(%d) are mismatching.", "input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, output_depth)); input_depth_tmp, output_depth));
auto input_height_tmp = (input_height + pad_h_up + pad_h_down - auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) / ((dilations[1] * (filter_height - 1) + 1))) /
strides[1] + strides[1] +
1; 1;
PADDLE_ENFORCE_EQ(input_height_tmp, output_height, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( input_height_tmp, output_height,
"input_height(%d)" platform::errors::InvalidArgument(
" and output_height(%d) are mismatching.", "input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, output_height)); input_height_tmp, output_height));
auto input_width_tmp = (input_width + pad_w_left + pad_w_right - auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) / ((dilations[2] * (filter_width - 1) + 1))) /
strides[2] + strides[2] +
1; 1;
PADDLE_ENFORCE_EQ(input_width_tmp, output_width, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( input_width_tmp, output_width,
"input_width(%d)" platform::errors::InvalidArgument(
" and output_width(%d) are mismatching.", "input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, output_width)); input_width_tmp, output_width));
T* vol_data = vol->data<T>(); T* vol_data = vol->data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
......
...@@ -90,16 +90,14 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> { ...@@ -90,16 +90,14 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* col, const std::vector<int>& paddings, framework::Tensor* col,
const DataLayout data_layout) const { const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
vol.dims().size(), 4, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("The dimension of" "The dimension of vol should be 4, but received %d.",
" vol should be 4, but received %d.", vol.dims().size()));
vol.dims().size())); PADDLE_ENFORCE_EQ(col->dims().size(), 7,
PADDLE_ENFORCE_EQ( platform::errors::InvalidArgument(
col->dims().size(), 7, "The dimension of col should be 7, but received %d.",
platform::errors::InvalidArgument("The dimension of" col->dims().size()));
"col should be 7, but received %d.",
col->dims().size()));
int input_channels = int input_channels =
(data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]); (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
...@@ -253,16 +251,14 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> { ...@@ -253,16 +251,14 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* vol, const std::vector<int>& paddings, framework::Tensor* vol,
const DataLayout data_layout) const { const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
vol->dims().size(), 4, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("The dimension of vol" "The dimension of vol should be 4, but received %d.",
" should be 4, but received %d.", vol->dims().size()));
vol->dims().size())); PADDLE_ENFORCE_EQ(col.dims().size(), 7,
PADDLE_ENFORCE_EQ( platform::errors::InvalidArgument(
col.dims().size(), 7, "The dimension of col should be 7, but received %d.",
platform::errors::InvalidArgument("The dimension of col" col.dims().size()));
" should be 7, but received %d.",
col.dims().size()));
int input_channels = int input_channels =
(data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]); (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
...@@ -291,29 +287,29 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> { ...@@ -291,29 +287,29 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
((dilations[0] * (filter_depth - 1) + 1))) / ((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] + strides[0] +
1; 1;
PADDLE_ENFORCE_EQ(input_depth_tmp, output_depth, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( input_depth_tmp, output_depth,
"input_depth(%d)" platform::errors::InvalidArgument(
" and output_depth(%d) are mismatching.", "input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, output_depth)); input_depth_tmp, output_depth));
auto input_height_tmp = (input_height + pad_h_up + pad_h_down - auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) / ((dilations[1] * (filter_height - 1) + 1))) /
strides[1] + strides[1] +
1; 1;
PADDLE_ENFORCE_EQ(input_height_tmp, output_height, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( input_height_tmp, output_height,
"input_height(%d)" platform::errors::InvalidArgument(
" and output_height(%d) are mismatching.", "input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, output_height)); input_height_tmp, output_height));
auto input_width_tmp = (input_width + pad_w_left + pad_w_right - auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) / ((dilations[2] * (filter_width - 1) + 1))) /
strides[2] + strides[2] +
1; 1;
PADDLE_ENFORCE_EQ(input_width_tmp, output_width, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( input_width_tmp, output_width,
"input_width(%d)" platform::errors::InvalidArgument(
" and output_width(%d) are mismatching.", "input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, output_width)); input_width_tmp, output_width));
int num_kernels = input_channels * input_depth * input_height * input_width; int num_kernels = input_channels * input_depth * input_height * input_width;
......
...@@ -86,8 +86,10 @@ class ConcatPrimitiveFactory { ...@@ -86,8 +86,10 @@ class ConcatPrimitiveFactory {
concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd, concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd,
Tensor* output, platform::CPUPlace place, Tensor* output, platform::CPUPlace place,
const mkldnn::engine& mkldnn_engine) { const mkldnn::engine& mkldnn_engine) {
dst_mem = mkldnn::memory(concat_pd.dst_desc(), mkldnn_engine, dst_mem = mkldnn::memory(
output->mutable_data<T>(place)); concat_pd.dst_desc(), mkldnn_engine,
output->mutable_data<T>(place, concat_pd.dst_desc().get_size()));
return concat(concat_pd); return concat(concat_pd);
} }
...@@ -193,7 +195,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -193,7 +195,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
prim_creator.SetSrcDataHandleByIndex( prim_creator.SetSrcDataHandleByIndex(
*srcs, i, to_void_cast<T>(multi_input[i]->data<T>())); *srcs, i, to_void_cast<T>(multi_input[i]->data<T>()));
} }
prim_creator.SetDstDataHandle(*dst_mem, output->mutable_data<T>(place)); prim_creator.SetDstDataHandle(
*dst_mem,
output->mutable_data<T>(place, concat_pd->dst_desc().get_size()));
} }
mkldnn::stream astream(mkldnn_engine); mkldnn::stream astream(mkldnn_engine);
......
...@@ -18,9 +18,10 @@ limitations under the License. */ ...@@ -18,9 +18,10 @@ limitations under the License. */
#include <set> #include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/cast_op.h" #include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" #include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
namespace paddle { namespace paddle {
...@@ -34,6 +35,110 @@ namespace operators { ...@@ -34,6 +35,110 @@ namespace operators {
} }
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using DDim = framework::DDim;
inline void GetShuffledDim(const DDim& src_dims, DDim* dst_dims,
const std::vector<int>& reduced_dims,
std::vector<int>* perm_axis) {
// check if it's a reduced dim
std::vector<bool> src_dims_check(src_dims.size(), false);
size_t src_size = src_dims.size();
size_t reduce_size = reduced_dims.size();
for (size_t i = 0; i < reduce_size; ++i) {
dst_dims->at(src_size - reduce_size + i) = src_dims[reduced_dims[i]];
(*perm_axis)[src_size - reduce_size + i] = reduced_dims[i];
src_dims_check[reduced_dims[i]] = true;
}
size_t offset = 0;
for (size_t i = 0; i < src_dims_check.size(); ++i) {
bool is_reduced = src_dims_check[i];
if (!is_reduced) {
(*perm_axis)[offset] = i;
dst_dims->at(offset++) = src_dims[i];
}
}
}
template <typename DeviceContext, typename OutT>
void GetShuffledInput(const framework::ExecutionContext& context,
const Tensor* input, Tensor* shuffled_input,
const std::vector<int>& dims) {
DDim shuffled_dims(input->dims());
std::vector<int> perm_axis(input->dims().size());
GetShuffledDim(input->dims(), &shuffled_dims, dims, &perm_axis);
shuffled_input->Resize(shuffled_dims);
shuffled_input->mutable_data<OutT>(context.GetPlace());
math::TransposeNormal<DeviceContext, OutT> trans;
trans(context.template device_context<DeviceContext>(), *input,
shuffled_input, perm_axis);
}
inline void GetOriginDimFromShuffled(const DDim& src_dim,
const std::vector<int>& dims,
std::vector<int>* origin_dim) {
DDim shuffled_dims(src_dim);
size_t n = src_dim.size();
std::vector<int> perm_axis(n);
GetShuffledDim(src_dim, &shuffled_dims, dims, &perm_axis);
for (size_t i = 0; i < n; ++i) {
(*origin_dim)[perm_axis[i]] = i;
}
}
template <typename DeviceContext, typename OutT, typename Functor>
void HandleLargeDim(const framework::ExecutionContext& context,
const Tensor* input, Tensor* output,
const std::vector<int>& dims, bool keep_dim) {
// shuffle the reduced dim to the end
Tensor shuffled_input;
GetShuffledInput<DeviceContext, OutT>(context, input, &shuffled_input, dims);
// transpose to 2D tensor whose shape is {unreduced, reduced}.
const int64_t unreduced = output->numel();
const int64_t reduced = shuffled_input.numel() / unreduced;
shuffled_input.Resize({unreduced, reduced});
DDim output_dim = output->dims();
output->Resize({unreduced});
ReduceFunctor<DeviceContext, OutT, 2, 1, Functor>(
context.template device_context<DeviceContext>(), shuffled_input, output,
{1}, keep_dim);
output->Resize(output_dim);
}
template <typename DeviceContext, typename T, typename Functor>
void HandleLargeDimGrad(const framework::ExecutionContext& context,
const framework::Tensor* x,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
const std::vector<int>& dims) {
const int64_t unreduced = out->numel();
const int64_t reduced = x->numel() / unreduced;
DDim out_dim(out->dims());
DDim x_dim(x->dims());
// transpose and reshape X
Tensor shuffled_x;
GetShuffledInput<DeviceContext, T>(context, x, &shuffled_x, dims);
DDim shuffled_dim = shuffled_x.dims();
shuffled_x.Resize({unreduced, reduced});
// reshape dX {unreduced, reduced}
dx->Resize({unreduced, reduced});
ReduceGradFunctor<DeviceContext, T, 2, Functor>(
context.template device_context<DeviceContext>(), shuffled_x, *out, *dout,
dx, {1});
// transpose dX
std::vector<int> origin_axis(x_dim.size());
GetOriginDimFromShuffled(x_dim, dims, &origin_axis);
Tensor dx_tmp;
framework::TensorCopy(*dx, context.GetPlace(), &dx_tmp);
dx_tmp.Resize(shuffled_dim);
dx->Resize(x_dim);
math::TransposeNormal<DeviceContext, T> trans;
trans(context.template device_context<DeviceContext>(), dx_tmp, dx,
origin_axis);
}
template <typename DeviceContext, typename T, typename Functor> template <typename DeviceContext, typename T, typename Functor>
struct ReduceKernelFunctor { struct ReduceKernelFunctor {
...@@ -69,22 +174,27 @@ struct ReduceKernelFunctor { ...@@ -69,22 +174,27 @@ struct ReduceKernelFunctor {
} else { } else {
int ndim = input->dims().size(); int ndim = input->dims().size();
int rdim = dims.size(); int rdim = dims.size();
HANDLE_DIM(6, 5); if (ndim > 6) {
HANDLE_DIM(6, 4); HandleLargeDim<DeviceContext, OutT, Functor>(context, input, output,
HANDLE_DIM(6, 3); dims, keep_dim);
HANDLE_DIM(6, 2); } else {
HANDLE_DIM(6, 1); HANDLE_DIM(6, 5);
HANDLE_DIM(5, 4); HANDLE_DIM(6, 4);
HANDLE_DIM(5, 3); HANDLE_DIM(6, 3);
HANDLE_DIM(5, 2); HANDLE_DIM(6, 2);
HANDLE_DIM(5, 1); HANDLE_DIM(6, 1);
HANDLE_DIM(4, 3); HANDLE_DIM(5, 4);
HANDLE_DIM(4, 2); HANDLE_DIM(5, 3);
HANDLE_DIM(4, 1); HANDLE_DIM(5, 2);
HANDLE_DIM(3, 2); HANDLE_DIM(5, 1);
HANDLE_DIM(3, 1); HANDLE_DIM(4, 3);
HANDLE_DIM(2, 1); HANDLE_DIM(4, 2);
HANDLE_DIM(1, 1); HANDLE_DIM(4, 1);
HANDLE_DIM(3, 2);
HANDLE_DIM(3, 1);
HANDLE_DIM(2, 1);
HANDLE_DIM(1, 1);
}
} }
} }
}; };
...@@ -137,7 +247,6 @@ class ReduceKernel : public framework::OpKernel<T> { ...@@ -137,7 +247,6 @@ class ReduceKernel : public framework::OpKernel<T> {
} }
} }
}; };
template <typename DeviceContext, typename OutT, typename Functor> template <typename DeviceContext, typename OutT, typename Functor>
class BoolReduceKernel : public framework::OpKernel<OutT> { class BoolReduceKernel : public framework::OpKernel<OutT> {
public: public:
...@@ -175,22 +284,27 @@ class BoolReduceKernel : public framework::OpKernel<OutT> { ...@@ -175,22 +284,27 @@ class BoolReduceKernel : public framework::OpKernel<OutT> {
int ndim = input->dims().size(); int ndim = input->dims().size();
int rdim = dims.size(); int rdim = dims.size();
// comments for accelerating compiling temporarily. // comments for accelerating compiling temporarily.
// HANDLE_DIM(6, 5); if (ndim > 6) {
// HANDLE_DIM(6, 4); HandleLargeDim<DeviceContext, OutT, Functor>(context, input, output,
// HANDLE_DIM(6, 3); dims, keep_dim);
// HANDLE_DIM(6, 2); } else {
// HANDLE_DIM(6, 1); HANDLE_DIM(6, 5);
// HANDLE_DIM(5, 4); HANDLE_DIM(6, 4);
// HANDLE_DIM(5, 3); HANDLE_DIM(6, 3);
// HANDLE_DIM(5, 2); HANDLE_DIM(6, 2);
// HANDLE_DIM(5, 1); HANDLE_DIM(6, 1);
HANDLE_DIM(4, 3); HANDLE_DIM(5, 4);
HANDLE_DIM(4, 2); HANDLE_DIM(5, 3);
HANDLE_DIM(4, 1); HANDLE_DIM(5, 2);
HANDLE_DIM(3, 2); HANDLE_DIM(5, 1);
HANDLE_DIM(3, 1); HANDLE_DIM(4, 3);
HANDLE_DIM(2, 1); HANDLE_DIM(4, 2);
HANDLE_DIM(1, 1); HANDLE_DIM(4, 1);
HANDLE_DIM(3, 2);
HANDLE_DIM(3, 1);
HANDLE_DIM(2, 1);
HANDLE_DIM(1, 1);
}
} }
} }
}; };
...@@ -279,6 +393,10 @@ class ReduceGradKernel : public framework::OpKernel<T> { ...@@ -279,6 +393,10 @@ class ReduceGradKernel : public framework::OpKernel<T> {
context.template device_context<DeviceContext>(), *input0, context.template device_context<DeviceContext>(), *input0,
*input1, *input2, output, dims); *input1, *input2, output, dims);
break; break;
default:
HandleLargeDimGrad<DeviceContext, T, Functor>(context, input0, input1,
input2, output, dims);
break;
} }
} }
} }
...@@ -313,12 +431,6 @@ class ReduceOp : public framework::OperatorWithKernel { ...@@ -313,12 +431,6 @@ class ReduceOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ReduceOp"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ReduceOp");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size(); auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6,
platform::errors::InvalidArgument(
"The input tensor X's dimensions of ReduceOp "
"should be less equal than 6. But received X's "
"dimensions = %d, X's shape = [%s].",
x_rank, x_dims));
auto dims = ctx->Attrs().Get<std::vector<int>>("dim"); auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
PADDLE_ENFORCE_GT(dims.size(), 0, PADDLE_ENFORCE_GT(dims.size(), 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -402,11 +514,6 @@ class ReduceGradOp : public framework::OperatorWithKernel { ...@@ -402,11 +514,6 @@ class ReduceGradOp : public framework::OperatorWithKernel {
"Out@GRAD", "ReduceOp"); "Out@GRAD", "ReduceOp");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size(); auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6,
platform::errors::InvalidArgument(
"Tensors with rank at most 6 are supported by "
"ReduceOp. Received tensor with rank %d.",
x_rank));
auto dims = ctx->Attrs().Get<std::vector<int>>("dim"); auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) { for (size_t i = 0; i < dims.size(); ++i) {
PADDLE_ENFORCE_LT(dims[i], x_rank, PADDLE_ENFORCE_LT(dims[i], x_rank,
......
...@@ -53,10 +53,9 @@ inline void TransCompute(const int dim, const DeviceContext& dev_ctx, ...@@ -53,10 +53,9 @@ inline void TransCompute(const int dim, const DeviceContext& dev_ctx,
trans6(dev_ctx, in, out, axis); trans6(dev_ctx, in, out, axis);
break; break;
default: default:
PADDLE_THROW(platform::errors::InvalidArgument( // for dim >= 7 situation
"Tensors with rank at most 6 are supported" math::TransposeNormal<DeviceContext, T> trans_normal;
", but received input tensor's rank is %d,", trans_normal(dev_ctx, in, out, axis);
dim));
} }
} }
......
...@@ -12,6 +12,7 @@ limitations under the License. */ ...@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include <set> #include <set>
#include <string> #include <string>
#include <thread> //NOLINT
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -23,6 +24,7 @@ limitations under the License. */ ...@@ -23,6 +24,7 @@ limitations under the License. */
#endif #endif
#include "glog/logging.h" #include "glog/logging.h"
#include "unsupported/Eigen/CXX11/ThreadPool"
namespace paddle { namespace paddle {
namespace memory { namespace memory {
...@@ -131,16 +133,31 @@ DeviceContextPool::DeviceContextPool( ...@@ -131,16 +133,31 @@ DeviceContextPool::DeviceContextPool(
CPUDeviceContext::CPUDeviceContext() { CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice()); eigen_device_.reset(new Eigen::DefaultDevice());
InitPoolDevice();
} }
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) { CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
eigen_device_.reset(new Eigen::DefaultDevice()); eigen_device_.reset(new Eigen::DefaultDevice());
InitPoolDevice();
}
void CPUDeviceContext::InitPoolDevice() {
using EigenEnv = Eigen::StlThreadEnvironment;
using EigenThreadPool = Eigen::ThreadPoolTempl<EigenEnv>;
int num_threads = std::thread::hardware_concurrency();
eigen_threadpool_.reset(new EigenThreadPool(num_threads));
eigen_pool_device_.reset(
new Eigen::ThreadPoolDevice(eigen_threadpool_.get(), num_threads));
} }
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
return eigen_device_.get(); return eigen_device_.get();
} }
Eigen::ThreadPoolDevice* CPUDeviceContext::eigen_pool_device() const {
return eigen_pool_device_.get();
}
Place CPUDeviceContext::GetPlace() const { return place_; } Place CPUDeviceContext::GetPlace() const { return place_; }
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
......
...@@ -41,6 +41,7 @@ limitations under the License. */ ...@@ -41,6 +41,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream/cuda_stream.h" #include "paddle/fluid/platform/stream/cuda_stream.h"
#endif #endif
#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
...@@ -65,11 +66,17 @@ class CPUDeviceContext : public DeviceContext { ...@@ -65,11 +66,17 @@ class CPUDeviceContext : public DeviceContext {
Eigen::DefaultDevice* eigen_device() const; Eigen::DefaultDevice* eigen_device() const;
Eigen::ThreadPoolDevice* eigen_pool_device() const;
Place GetPlace() const override; Place GetPlace() const override;
inline void InitPoolDevice();
private: private:
CPUPlace place_; CPUPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_pool_device_;
std::unique_ptr<Eigen::ThreadPool> eigen_threadpool_;
}; };
template <typename Place> template <typename Place>
......
...@@ -621,6 +621,7 @@ function generate_upstream_develop_api_spec() { ...@@ -621,6 +621,7 @@ function generate_upstream_develop_api_spec() {
git checkout -b develop_base_pr upstream/$BRANCH git checkout -b develop_base_pr upstream/$BRANCH
cmake_gen $1 cmake_gen $1
build $2 build $2
cp ${PADDLE_ROOT}/python/requirements.txt /tmp
git checkout $cur_branch git checkout $cur_branch
generate_api_spec "$1" "DEV" generate_api_spec "$1" "DEV"
...@@ -641,7 +642,12 @@ function generate_api_spec() { ...@@ -641,7 +642,12 @@ function generate_api_spec() {
cd ${PADDLE_ROOT}/build/.check_api_workspace cd ${PADDLE_ROOT}/build/.check_api_workspace
virtualenv .${spec_kind}_env virtualenv .${spec_kind}_env
source .${spec_kind}_env/bin/activate source .${spec_kind}_env/bin/activate
pip install -r ${PADDLE_ROOT}/python/requirements.txt
if [ "$spec_kind" == "DEV" ]; then
pip install -r /tmp/requirements.txt
else
pip install -r ${PADDLE_ROOT}/python/requirements.txt
fi
pip --no-cache-dir install ${PADDLE_ROOT}/build/python/dist/*whl pip --no-cache-dir install ${PADDLE_ROOT}/build/python/dist/*whl
spec_path=${PADDLE_ROOT}/paddle/fluid/API_${spec_kind}.spec spec_path=${PADDLE_ROOT}/paddle/fluid/API_${spec_kind}.spec
python ${PADDLE_ROOT}/tools/print_signatures.py paddle > $spec_path python ${PADDLE_ROOT}/tools/print_signatures.py paddle > $spec_path
...@@ -930,6 +936,10 @@ function parallel_test_base_gpu() { ...@@ -930,6 +936,10 @@ function parallel_test_base_gpu() {
EOF EOF
set +x set +x
precison_cases=""
if [ ${PRECISION_TEST:-OFF} == "ON" ]; then
precision_cases=`python $PADDLE_ROOT/tools/get_pr_ut.py`
fi
EXIT_CODE=0; EXIT_CODE=0;
test_cases=$(ctest -N -V) # get all test cases test_cases=$(ctest -N -V) # get all test cases
exclusive_tests='' # cases list which would be run exclusively exclusive_tests='' # cases list which would be run exclusively
...@@ -959,10 +969,23 @@ set +x ...@@ -959,10 +969,23 @@ set +x
echo $testcase" will only run at night." echo $testcase" will only run at night."
continue continue
fi fi
if [ ${PRECISION_TEST:-OFF} == "ON" ] && [[ "$precision_cases" != "" ]]; then
will_test="false"
for case in $precision_cases; do
if [[ $testcase == $case ]]; then
will_test="true"
break
fi
done
if [[ $will_test == "false" ]]; then
echo $testcase" won't run in PRECISION_TEST mode."
continue
fi
fi
if [[ "$is_multicard" == "" ]]; then if [[ "$is_multicard" == "" ]]; then
# trick: treat all test case with prefix "test_dist" as dist case, and would run on 2 GPUs # trick: treat all test case with prefix "test_dist" as dist case, and would run on 2 GPUs
read is_multicard <<< $(echo "$testcase"|grep -oEi "test_dist") read is_multicard <<< $(echo "$testcase"|grep -oEi "test_dist_")
fi fi
if [[ "$is_exclusive" != "" ]]; then if [[ "$is_exclusive" != "" ]]; then
...@@ -1077,8 +1100,6 @@ set +x ...@@ -1077,8 +1100,6 @@ set +x
done done
fi fi
if [[ "$EXIT_CODE" != "0" ]]; then if [[ "$EXIT_CODE" != "0" ]]; then
if [[ "$failed_test_lists" == "" ]]; then if [[ "$failed_test_lists" == "" ]]; then
echo "========================================" echo "========================================"
......
...@@ -728,6 +728,63 @@ class DistributedStrategy(object): ...@@ -728,6 +728,63 @@ class DistributedStrategy(object):
"localsgd_configs") "localsgd_configs")
assign_configs_value(self.strategy.localsgd_configs, configs) assign_configs_value(self.strategy.localsgd_configs, configs)
@property
def adaptive_localsgd(self):
"""
Indicating whether we are using Adaptive Local SGD training. Default Value: False
For more details, please refer to `Adaptive Communication Strategies to Achieve
the Best Error-Runtime Trade-off in Local-Update SGD <https://arxiv.org/pdf/1810.08313.pdf>`_.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.adaptive_localsgd = True # by default this is false
"""
return self.strategy.localsgd
@adaptive_localsgd.setter
@is_strict_auto
def adaptive_localsgd(self, flag):
if isinstance(flag, bool):
self.strategy.localsgd = flag
else:
print("WARNING: adaptive_localsgd should have value of bool type")
@property
def adaptive_localsgd_configs(self):
"""
Set AdaptiveLocalSGD training configurations. AdaptiveLocalSGD has a configurable
setting that can be configured through a dict.
**Notes**:
init_k_steps(int) The initial steps for training before adaptive localsgd.
Then, the adaptive localsgd method will modify init_k_steps automatically.
Default 1.
begin_step(int) The step of begining training by adaptive localsgd. Default 1.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.adaptive_localsgd = True
strategy.adaptive_localsgd_configs = {"init_k_steps": 1,
"begin_step": 30}
"""
return get_msg_dict(self.strategy.adaptive_localsgd_configs)
@adaptive_localsgd_configs.setter
@is_strict_auto
def adaptive_localsgd_configs(self, configs):
check_configs_key(self.strategy.adaptive_localsgd_configs, configs,
"adaptive_localsgd_configs")
assign_configs_value(self.strategy.adaptive_localsgd_configs, configs)
@property @property
def dgc(self): def dgc(self):
""" """
......
...@@ -608,25 +608,31 @@ class Fleet(object): ...@@ -608,25 +608,31 @@ class Fleet(object):
@dygraph_only @dygraph_only
def distributed_model(self, model): def distributed_model(self, model):
""" """
Return dygraph distributed data parallel model (Layer) Return distributed data parallel model (Only work in dygraph mode)
Only work in dygraph mode
Args:
model (Layer): the user-defind model which inherits Layer.
Returns:
distributed data parallel model which inherits Layer.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle
import paddle.nn as nn
from paddle.distributed import fleet
class LinearNet(nn.Layer): import paddle
def __init__(self): import paddle.nn as nn
super(LinearNet, self).__init__() from paddle.distributed import fleet
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1) class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x): def forward(self, x):
return self._linear2(self._linear1(x)) return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode # 1. enable dynamic mode
paddle.disable_static() paddle.disable_static()
...@@ -658,8 +664,7 @@ class Fleet(object): ...@@ -658,8 +664,7 @@ class Fleet(object):
adam.step() adam.step()
adam.clear_grad() adam.clear_grad()
if __name__ == '__main__':
paddle.distributed.spawn(train)
""" """
assert model is not None assert model is not None
self.model = paddle.DataParallel(model) self.model = paddle.DataParallel(model)
...@@ -669,29 +674,30 @@ class Fleet(object): ...@@ -669,29 +674,30 @@ class Fleet(object):
def state_dict(self): def state_dict(self):
""" """
Get state dict information from optimizer. Get state dict information from optimizer.
Only work in dygraph mode (Only work in dygraph mode)
Returns: Returns:
state_dict(dict) : dict contains all the Tensor used by optimizer state_dict(dict) : dict contains all the Tensor used by optimizer
Examples: Examples:
.. code-block:: python .. code-block:: python
import numpy as np
import paddle
from paddle.distributed import fleet
paddle.disable_static() import numpy as np
fleet.init(is_collective=True) import paddle
from paddle.distributed import fleet
paddle.disable_static()
fleet.init(is_collective=True)
value = np.arange(26).reshape(2, 13).astype("float32") value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.fluid.dygraph.to_variable(value) a = paddle.fluid.dygraph.to_variable(value)
layer = paddle.nn.Linear(13, 5) layer = paddle.nn.Linear(13, 5)
adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters()) adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
adam = fleet.distributed_optimizer(adam) adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer) dp_layer = fleet.distributed_model(layer)
state_dict = adam.state_dict() state_dict = adam.state_dict()
""" """
# imitate target optimizer retrieval # imitate target optimizer retrieval
return self.user_defined_optimizer.state_dict() return self.user_defined_optimizer.state_dict()
...@@ -700,34 +706,36 @@ class Fleet(object): ...@@ -700,34 +706,36 @@ class Fleet(object):
def set_state_dict(self, state_dict): def set_state_dict(self, state_dict):
""" """
Load optimizer state dict. Load optimizer state dict.
Only work in dygraph mode (Only work in dygraph mode)
Args: Args:
state_dict(dict) : Dict contains all the Tensor needed by optimizer state_dict(dict) : Dict contains all the Tensor needed by optimizer
Returns: None Returns:
None
Examples: Examples:
.. code-block:: python .. code-block:: python
import numpy as np
import paddle
from paddle.distributed import fleet
paddle.disable_static() import numpy as np
fleet.init(is_collective=True) import paddle
from paddle.distributed import fleet
paddle.disable_static()
fleet.init(is_collective=True)
value = np.arange(26).reshape(2, 13).astype("float32") value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.fluid.dygraph.to_variable(value) a = paddle.fluid.dygraph.to_variable(value)
layer = paddle.nn.Linear(13, 5) layer = paddle.nn.Linear(13, 5)
adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters()) adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
adam = fleet.distributed_optimizer(adam) adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer) dp_layer = fleet.distributed_model(layer)
state_dict = adam.state_dict() state_dict = adam.state_dict()
paddle.framework.save(state_dict, "paddle_dy") paddle.framework.save(state_dict, "paddle_dy")
para_state_dict, opti_state_dict = paddle.framework.load( "paddle_dy") para_state_dict, opti_state_dict = paddle.framework.load( "paddle_dy")
adam.set_state_dict(opti_state_dict) adam.set_state_dict(opti_state_dict)
""" """
# imitate target optimizer retrieval # imitate target optimizer retrieval
return self.user_defined_optimizer.set_state_dict(state_dict) return self.user_defined_optimizer.set_state_dict(state_dict)
...@@ -736,42 +744,44 @@ class Fleet(object): ...@@ -736,42 +744,44 @@ class Fleet(object):
def set_lr(self, value): def set_lr(self, value):
""" """
Set the value of the learning rate manually in the optimizer. Set the value of the learning rate manually in the optimizer.
Only work in dygraph mode (Only work in dygraph mode)
Args: Args:
value (float|Tensor): the value of learning rate value (float|Tensor): the value of learning rate
Returns: None Returns:
None
Examples: Examples:
.. code-block:: python .. code-block:: python
import numpy as np
import paddle
from paddle.distributed import fleet
paddle.disable_static() import numpy as np
fleet.init(is_collective=True) import paddle
from paddle.distributed import fleet
value = np.arange(26).reshape(2, 13).astype("float32") paddle.disable_static()
a = paddle.fluid.dygraph.to_variable(value) fleet.init(is_collective=True)
layer = paddle.nn.Linear(13, 5) value = np.arange(26).reshape(2, 13).astype("float32")
adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters()) a = paddle.fluid.dygraph.to_variable(value)
adam = fleet.distributed_optimizer(adam) layer = paddle.nn.Linear(13, 5)
dp_layer = fleet.distributed_model(layer) adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
lr_list = [0.2, 0.3, 0.4, 0.5, 0.6] adam = fleet.distributed_optimizer(adam)
for i in range(5): dp_layer = fleet.distributed_model(layer)
adam.set_lr(lr_list[i])
lr = adam.get_lr() lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
print("current lr is {}".format(lr)) for i in range(5):
# Print: adam.set_lr(lr_list[i])
# current lr is 0.2 lr = adam.get_lr()
# current lr is 0.3 print("current lr is {}".format(lr))
# current lr is 0.4 # Print:
# current lr is 0.5 # current lr is 0.2
# current lr is 0.6 # current lr is 0.3
# current lr is 0.4
# current lr is 0.5
# current lr is 0.6
""" """
# imitate target optimizer retrieval # imitate target optimizer retrieval
return self.user_defined_optimizer.set_lr(value) return self.user_defined_optimizer.set_lr(value)
...@@ -780,31 +790,32 @@ class Fleet(object): ...@@ -780,31 +790,32 @@ class Fleet(object):
def get_lr(self): def get_lr(self):
""" """
Get current step learning rate. Get current step learning rate.
Only work in dygraph mode (Only work in dygraph mode)
Returns: Returns:
float: The learning rate of the current step. float: The learning rate of the current step.
Examples: Examples:
.. code-block:: python .. code-block:: python
import numpy as np
import paddle
from paddle.distributed import fleet
paddle.disable_static() import numpy as np
fleet.init(is_collective=True) import paddle
from paddle.distributed import fleet
paddle.disable_static()
fleet.init(is_collective=True)
value = np.arange(26).reshape(2, 13).astype("float32") value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.fluid.dygraph.to_variable(value) a = paddle.fluid.dygraph.to_variable(value)
layer = paddle.nn.Linear(13, 5) layer = paddle.nn.Linear(13, 5)
adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters()) adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
adam = fleet.distributed_optimizer(adam) adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer) dp_layer = fleet.distributed_model(layer)
lr = adam.get_lr() lr = adam.get_lr()
print(lr) # 0.01 print(lr) # 0.01
""" """
# imitate target optimizer retrieval # imitate target optimizer retrieval
return self.user_defined_optimizer.get_lr() return self.user_defined_optimizer.get_lr()
...@@ -813,27 +824,27 @@ class Fleet(object): ...@@ -813,27 +824,27 @@ class Fleet(object):
def step(self): def step(self):
""" """
Execute the optimizer once. Execute the optimizer once.
Only work in dygraph mode (Only work in dygraph mode)
Returns: None Returns:
None
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from paddle.distributed import fleet from paddle.distributed import fleet
class LinearNet(nn.Layer): class LinearNet(nn.Layer):
def __init__(self): def __init__(self):
super(LinearNet, self).__init__() super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10) self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1) self._linear2 = nn.Linear(10, 1)
def forward(self, x): def forward(self, x):
return self._linear2(self._linear1(x)) return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode # 1. enable dynamic mode
paddle.disable_static() paddle.disable_static()
...@@ -865,8 +876,6 @@ class Fleet(object): ...@@ -865,8 +876,6 @@ class Fleet(object):
adam.step() adam.step()
adam.clear_grad() adam.clear_grad()
if __name__ == '__main__':
paddle.distributed.spawn(train)
""" """
# imitate target optimizer retrieval # imitate target optimizer retrieval
...@@ -875,28 +884,28 @@ class Fleet(object): ...@@ -875,28 +884,28 @@ class Fleet(object):
@dygraph_only @dygraph_only
def clear_grad(self): def clear_grad(self):
""" """
Execute the optimizer once. Clear the gradients of all optimized parameters for model.
Only work in dygraph mode (Only work in dygraph mode)
Returns: None Returns:
None
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from paddle.distributed import fleet from paddle.distributed import fleet
class LinearNet(nn.Layer): class LinearNet(nn.Layer):
def __init__(self): def __init__(self):
super(LinearNet, self).__init__() super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10) self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1) self._linear2 = nn.Linear(10, 1)
def forward(self, x): def forward(self, x):
return self._linear2(self._linear1(x)) return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode # 1. enable dynamic mode
paddle.disable_static() paddle.disable_static()
...@@ -928,8 +937,6 @@ class Fleet(object): ...@@ -928,8 +937,6 @@ class Fleet(object):
adam.step() adam.step()
adam.clear_grad() adam.clear_grad()
if __name__ == '__main__':
paddle.distributed.spawn(train)
""" """
# imitate target optimizer retrieval # imitate target optimizer retrieval
return self.user_defined_optimizer.clear_grad() return self.user_defined_optimizer.clear_grad()
......
...@@ -637,7 +637,7 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -637,7 +637,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
return "lo" return "lo"
def __start_kv_server(self, http_server_d, size_d): def __start_kv_server(self, http_server_d, size_d):
from paddle.distributed.fleet.utils import KVServer from paddle.distributed.fleet.utils.http_server import KVServer
http_server = KVServer(int(self._http_ip_port[1]), size_d) http_server = KVServer(int(self._http_ip_port[1]), size_d)
http_server.start() http_server.start()
wait_seconds = 5 wait_seconds = 5
...@@ -651,6 +651,7 @@ class UserDefinedRoleMaker(PaddleCloudRoleMaker): ...@@ -651,6 +651,7 @@ class UserDefinedRoleMaker(PaddleCloudRoleMaker):
def __init__(self, is_collective=False, init_gloo=False, **kwargs): def __init__(self, is_collective=False, init_gloo=False, **kwargs):
super(UserDefinedRoleMaker, self).__init__( super(UserDefinedRoleMaker, self).__init__(
is_collective=is_collective, init_gloo=init_gloo, **kwargs) is_collective=is_collective, init_gloo=init_gloo, **kwargs)
self._init_gloo = init_gloo
def _user_defined_ps_env(self): def _user_defined_ps_env(self):
self._server_endpoints = self._kwargs.get("server_endpoints") self._server_endpoints = self._kwargs.get("server_endpoints")
......
...@@ -16,20 +16,18 @@ ...@@ -16,20 +16,18 @@
"""basic collective operations in python""" """basic collective operations in python"""
"""remote file system""" """remote file system"""
__all__ = ['UtilBase']
import numpy as np
import os
import subprocess
from paddle.fluid import core
from collections import OrderedDict
import paddle.fluid as fluid
from google.protobuf import text_format
from paddle.fluid import debugger
from paddle.fluid.framework import Program
from paddle.fluid.proto import framework_pb2
from ..utils.fs import FS, LocalFS, HDFSClient from ..utils.fs import FS, LocalFS, HDFSClient
from paddle.fluid.proto import framework_pb2
from paddle.fluid.framework import Program
from paddle.fluid import debugger
from google.protobuf import text_format
import paddle.fluid as fluid
from collections import OrderedDict
from paddle.fluid import core
import subprocess
import os
import numpy as np
__all__ = ['UtilBase']
class UtilFactory(object): class UtilFactory(object):
...@@ -53,7 +51,7 @@ class UtilBase(object): ...@@ -53,7 +51,7 @@ class UtilBase(object):
def _set_role_maker(self, role_maker): def _set_role_maker(self, role_maker):
self.role_maker = role_maker self.role_maker = role_maker
def set_file_system(self, fs_client): def _set_file_system(self, fs_client):
assert isinstance( assert isinstance(
fs_client, FS fs_client, FS
), "fs_client must be the instance of paddle.distributed.fleet.utils.FS" ), "fs_client must be the instance of paddle.distributed.fleet.utils.FS"
...@@ -87,36 +85,183 @@ class UtilBase(object): ...@@ -87,36 +85,183 @@ class UtilBase(object):
return _comm_world return _comm_world
def all_reduce(self, input, mode, comm_world="worker"): def all_reduce(self, input, mode, comm_world="worker"):
"""
All reduce `input` between specified collection. This is a distributed API.
Args:
input (list|numpy.array): The input variable to do all_reduce between specified collection.
mode (str): "sum" or "min" or "max".
comm_world (str, optional): Collection used to execute all_reduce operation. Supported collections incude `worker` , `server` and `all` . The default is `worker` .
Returns:
output(Numpy.array|None): A numpy array with the same shape as the `input` .
Examples:
.. code-block:: python
# Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
from paddle.distributed.fleet.base.util_factory import fleet_util
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet import PaddleCloudRoleMaker
import sys
import numpy as np
def train():
role = PaddleCloudRoleMaker(
is_collective=False,
init_gloo=True,
path="./tmp_gloo")
fleet.init(role)
fleet_util._set_role_maker(role)
if fleet.is_server():
input = [1, 2]
output = fleet_util.all_reduce(input, "sum", "server")
print(output)
# [2, 4]
elif fleet.is_worker():
input = np.array([3, 4])
output = fleet_util.all_reduce(input, "sum", "worker")
print(output)
# [6, 8]
output = fleet_util.all_reduce(input, "sum", "all")
print(output)
# [8, 12]
if __name__ == "__main__":
train()
"""
_comm_world = self.__check_comm_world(comm_world) _comm_world = self.__check_comm_world(comm_world)
return self.role_maker._all_reduce(_comm_world, input, mode) return self.role_maker._all_reduce(_comm_world, input, mode)
def barrier(self, comm_world="worker"): def barrier(self, comm_world="worker"):
"""
Barrier between specified collection.
Args:
comm_world (str, optional): Collection used to execute barrier operation. Supported collections incude `worker` , `server` and `all` . The default is `worker` .
Examples:
.. code-block:: python
# Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
from paddle.distributed.fleet.base.util_factory import fleet_util
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet import PaddleCloudRoleMaker
import sys
def train():
role = PaddleCloudRoleMaker(
is_collective=False,
init_gloo=True,
path="./tmp_gloo")
fleet.init(role)
fleet_util._set_role_maker(role)
if fleet.is_server():
fleet_util.barrier("server")
print("all server arrive here")
elif fleet.is_worker():
fleet_util.barrier("worker")
print("all server arrive here")
fleet_util.barrier("all")
print("all servers and workers arrive here")
if __name__ == "__main__":
train()
"""
_comm_world = self.__check_comm_world(comm_world) _comm_world = self.__check_comm_world(comm_world)
self.role_maker._barrier(_comm_world) self.role_maker._barrier(_comm_world)
def all_gather(self, input, comm_world="worker"): def all_gather(self, input, comm_world="worker"):
"""
All gather `input` between specified collection.
Args:
input (Int|Float): The input variable to do all_gather between specified collection.
comm_world (str, optional): Collection used to execute all_reduce operation. Supported collections incude `worker` , `server` and `all` . The default is `worker` .
Returns:
output (List): A list of gathered values.
Examples:
.. code-block:: python
# Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
from paddle.distributed.fleet.base.util_factory import fleet_util
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet import PaddleCloudRoleMaker
import sys
def train():
role = PaddleCloudRoleMaker(
is_collective=False,
init_gloo=True,
path="./tmp_gloo")
fleet.init(role)
fleet_util._set_role_maker(role)
if fleet.is_server():
input = fleet.server_index()
output = fleet_util.all_gather(input, "server")
print(output)
# output = [0, 1]
elif fleet.is_worker():
input = fleet.worker_index()
output = fleet_util.all_gather(input, "worker")
# output = [0, 1]
print(output)
output = fleet_util.all_gather(input, "all")
print(output)
# output = [0, 1, 0, 1]
if __name__ == "__main__":
train()
"""
_comm_world = self.__check_comm_world(comm_world) _comm_world = self.__check_comm_world(comm_world)
return self.role_maker._all_gather(_comm_world, input) return self.role_maker._all_gather(_comm_world, input)
def broadcast(self): def _broadcast(self):
pass pass
def scatter(self): def _scatter(self):
pass pass
def get_file_shard(self, files): def get_file_shard(self, files):
""" """
split files before distributed training, Split files before distributed training, and return filelist assigned to the current trainer.
example 1: files is [a, b, c ,d, e] and trainer_num = 2, then trainer
0 gets [a, b, c] and trainer 1 gets [d, e]. .. code-block:: text
example 2: files is [a, b], and trainer_num = 3, then trainer 0 gets
[a], trainer 1 gets [b], trainer 2 gets [] example 1: files is [a, b, c ,d, e] and trainer_num = 2, then trainer
0 gets [a, b, c] and trainer 1 gets [d, e].
example 2: files is [a, b], and trainer_num = 3, then trainer 0 gets
[a], trainer 1 gets [b], trainer 2 gets []
Args: Args:
files(list): file list need to be read. files(list): File list need to be read.
Returns: Returns:
list: files belongs to this worker. List: Files belong to this worker.
Examples:
.. code-block:: python
from paddle.distributed.fleet.base.util_factory import fleet_util
import paddle.distributed.fleet.base.role_maker as role_maker
role = role_maker.UserDefinedRoleMaker(
is_collective=False,
init_gloo=False,
current_id=0,
role=role_maker.Role.WORKER,
worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
fleet_util._set_role_maker(role)
files = fleet_util.get_file_shard(["file1", "file2", "file3"])
# files = ["file1", "file2"]
""" """
if not isinstance(files, list): if not isinstance(files, list):
raise TypeError("files should be a list of file need to be read.") raise TypeError("files should be a list of file need to be read.")
...@@ -140,6 +285,30 @@ class UtilBase(object): ...@@ -140,6 +285,30 @@ class UtilBase(object):
return trainer_files[trainer_id] return trainer_files[trainer_id]
def print_on_rank(self, message, rank_id): def print_on_rank(self, message, rank_id):
"""
Woker of rank `rank_id` print some message.
Args:
message(str): Log to be printed.
rank_id(int): trainer id.
Examples:
.. code-block:: python
from paddle.distributed.fleet.base.util_factory import fleet_util
import paddle.distributed.fleet.base.role_maker as role_maker
role = role_maker.UserDefinedRoleMaker(
is_collective=False,
init_gloo=False,
current_id=0,
role=role_maker.Role.WORKER,
worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
fleet_util._set_role_maker(role)
fleet_util.print_on_rank("I'm worker 0", 0)
"""
if self.role_maker.worker_index() != rank_id: if self.role_maker.worker_index() != rank_id:
return return
print(message) print(message)
...@@ -297,7 +466,7 @@ class UtilBase(object): ...@@ -297,7 +466,7 @@ class UtilBase(object):
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
inference_program, feed_target_names, fetch_targets = \ inference_program, feed_target_names, fetch_targets = \
fluid.io.load_inference_model(config.dump_model_dir, exe, model_filename=model_filename, fluid.io.load_inference_model(config.dump_model_dir, exe, model_filename=model_filename,
params_filename=config.save_params_filename) params_filename=config.save_params_filename)
# check program vars and saved vars shape # check program vars and saved vars shape
orig_para_shape = { orig_para_shape = {
......
...@@ -87,7 +87,7 @@ def _parse_args(): ...@@ -87,7 +87,7 @@ def _parse_args():
see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/training/cluster_howto.html#permalink-8--nccl2- see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/training/cluster_howto.html#permalink-8--nccl2-
''') ''')
#Optional arguments for the launch helper # Optional arguments for the launch helper
parser.add_argument( parser.add_argument(
"--ips", "--ips",
type=str, type=str,
...@@ -115,7 +115,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra ...@@ -115,7 +115,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
default="log", default="log",
help="The path for each process's log.If it's not set, the log will printed to default pipe." help="The path for each process's log.If it's not set, the log will printed to default pipe."
) )
#positional # positional
parser.add_argument( parser.add_argument(
"training_script", "training_script",
type=str, type=str,
...@@ -124,7 +124,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra ...@@ -124,7 +124,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
"followed by all the arguments for the " "followed by all the arguments for the "
"training script") "training script")
#rest from the training program # rest from the training program
parser.add_argument('training_script_args', nargs=REMAINDER) parser.add_argument('training_script_args', nargs=REMAINDER)
return parser.parse_args() return parser.parse_args()
...@@ -138,7 +138,7 @@ def get_cluster_from_args(args, gpus): ...@@ -138,7 +138,7 @@ def get_cluster_from_args(args, gpus):
# node_ip = args.node_ip # node_ip = args.node_ip
assert node_ip in node_ips, "Can't find your local ip {%s} in node_ips: {%s}" \ assert node_ip in node_ips, "Can't find your local ip {%s} in node_ips: {%s}" \
% (node_ip, node_ips) % (node_ip, node_ips)
node_rank = node_ips.index(node_ip) node_rank = node_ips.index(node_ip)
logger.debug("parsed from args: node_ips:{} node_ip:{} node_rank:{}".format( logger.debug("parsed from args: node_ips:{} node_ip:{} node_rank:{}".format(
...@@ -280,7 +280,7 @@ def launch_ps(args): ...@@ -280,7 +280,7 @@ def launch_ps(args):
_, current_node_ip = get_host_name_ip() _, current_node_ip = get_host_name_ip()
assert current_node_ip in node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \ assert current_node_ip in node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \
% (current_node_ip, node_ips) % (current_node_ip, node_ips)
node_rank = node_ips.index(current_node_ip) node_rank = node_ips.index(current_node_ip)
logger.debug( logger.debug(
"parsed from args: node_ips:{} current_node_ip:{} node_rank:{}, server_ports:{}". "parsed from args: node_ips:{} current_node_ip:{} node_rank:{}, server_ports:{}".
...@@ -323,10 +323,12 @@ def launch_ps(args): ...@@ -323,10 +323,12 @@ def launch_ps(args):
for idx, cur_server in enumerate(pod.servers): for idx, cur_server in enumerate(pod.servers):
proc_env = { proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": server_endpoints, "PADDLE_PSERVERS_IP_PORT_LIST": server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": worker_endpoints,
"PADDLE_PORT": cur_server.endpoint.split(":")[1], "PADDLE_PORT": cur_server.endpoint.split(":")[1],
"TRAINING_ROLE": "PSERVER", "TRAINING_ROLE": "PSERVER",
"PADDLE_TRAINERS_NUM": str(worker_num), "PADDLE_TRAINERS_NUM": str(worker_num),
"POD_IP": cur_server.endpoint.split(":")[0] "POD_IP": cur_server.endpoint.split(":")[0],
"PADDLE_WITH_GLOO": "1"
} }
current_env.update(proc_env) current_env.update(proc_env)
...@@ -365,7 +367,8 @@ def launch_ps(args): ...@@ -365,7 +367,8 @@ def launch_ps(args):
"PADDLE_TRAINER_ENDPOINTS": worker_endpoints, "PADDLE_TRAINER_ENDPOINTS": worker_endpoints,
"PADDLE_TRAINERS_NUM": str(worker_num), "PADDLE_TRAINERS_NUM": str(worker_num),
"TRAINING_ROLE": "TRAINER", "TRAINING_ROLE": "TRAINER",
"PADDLE_TRAINER_ID": str(cur_worker.rank) "PADDLE_TRAINER_ID": str(cur_worker.rank),
"PADDLE_WITH_GLOO": "1"
} }
current_env.update(proc_env) current_env.update(proc_env)
...@@ -430,7 +433,11 @@ def launch(): ...@@ -430,7 +433,11 @@ def launch():
co_arg for co_arg in collective_args co_arg for co_arg in collective_args
if co_arg in " ".join(sys.argv[1:-1]) if co_arg in " ".join(sys.argv[1:-1])
] ]
cuda_device_num = fluid.core.get_cuda_device_count() if fluid.core.is_compiled_with_cuda():
cuda_device_num = fluid.core.get_cuda_device_count()
else:
cuda_device_num = 0
if len(has_ps_args) > 0 or cuda_device_num == 0: if len(has_ps_args) > 0 or cuda_device_num == 0:
logger.info( logger.info(
"Run parameter-sever cpu mode. pserver arguments:{}, cuda count:{}". "Run parameter-sever cpu mode. pserver arguments:{}, cuda count:{}".
......
...@@ -18,6 +18,7 @@ from .graph_execution_optimizer import GraphExecutionOptimizer ...@@ -18,6 +18,7 @@ from .graph_execution_optimizer import GraphExecutionOptimizer
from .parameter_server_optimizer import ParameterServerOptimizer from .parameter_server_optimizer import ParameterServerOptimizer
from .pipeline_optimizer import PipelineOptimizer from .pipeline_optimizer import PipelineOptimizer
from .localsgd_optimizer import LocalSGDOptimizer from .localsgd_optimizer import LocalSGDOptimizer
from .localsgd_optimizer import AdaptiveLocalSGDOptimizer
from .lars_optimizer import LarsOptimizer from .lars_optimizer import LarsOptimizer
from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer
from .dgc_optimizer import DGCOptimizer from .dgc_optimizer import DGCOptimizer
......
...@@ -22,9 +22,13 @@ class AMPOptimizer(MetaOptimizerBase): ...@@ -22,9 +22,13 @@ class AMPOptimizer(MetaOptimizerBase):
self.amp_opt = None self.amp_opt = None
# we do not allow meta optimizer to be inner optimizer currently # we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = [ self.meta_optimizers_white_list = [
"LarsOptimizer", "LambOptimizer", "RecomputeOptimizer", "LarsOptimizer",
"LocalSGDOptimizer", "GradientMergeOptimizer", "LambOptimizer",
"GraphExecutionOptimizer" "RecomputeOptimizer",
"LocalSGDOptimizer",
"GradientMergeOptimizer",
"GraphExecutionOptimizer",
"AdaptiveLocalSGDOptimizer",
] ]
self.meta_optimizers_black_list = ["DGCOptimizer"] self.meta_optimizers_black_list = ["DGCOptimizer"]
......
...@@ -25,7 +25,10 @@ class LocalSGDOptimizer(MetaOptimizerBase): ...@@ -25,7 +25,10 @@ class LocalSGDOptimizer(MetaOptimizerBase):
super(LocalSGDOptimizer, self).__init__(optimizer) super(LocalSGDOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer self.inner_opt = optimizer
self.meta_optimizers_white_list = [] self.meta_optimizers_white_list = []
self.meta_optimizers_black_list = ["GraphExecutionOptimizer"] self.meta_optimizers_black_list = [
"GraphExecutionOptimizer",
"AdaptiveLocalSGDOptimizer",
]
self.snapshot_key = '@SNAPSHOT' self.snapshot_key = '@SNAPSHOT'
def _can_apply(self): def _can_apply(self):
...@@ -186,3 +189,252 @@ class LocalSGDOptimizer(MetaOptimizerBase): ...@@ -186,3 +189,252 @@ class LocalSGDOptimizer(MetaOptimizerBase):
layers.cond(step > begin_step, begin_localsgd, communicate) layers.cond(step > begin_step, begin_localsgd, communicate)
return minimized return minimized
class AdaptiveLocalSGDOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(AdaptiveLocalSGDOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
self.meta_optimizers_white_list = []
self.meta_optimizers_black_list = [
"GraphExecutionOptimizer", "LocalSGDOptimizer"
]
self.snapshot_key = '@SNAPSHOT'
def _can_apply(self):
if not self.role_maker._is_collective:
return False
if not self.user_defined_strategy.adaptive_localsgd:
return False
if self.role_maker.worker_num() <= 1:
return False
return isinstance(self.inner_opt, paddle.optimizer.momentum.Momentum) \
or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum) \
or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD) \
or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD)
def _disable_strategy(self, dist_strategy):
dist_strategy.adaptive_localsgd = False
dist_strategy.adaptive_localsgd_configs = {}
def _enable_strategy(self, dist_strategy, context):
dist_strategy.adaptive_localsgd = True
dist_strategy.adaptive_localsgd_configs = {
"init_k_steps": 1,
"begin_step": 1
}
def snapshot_name(self, param_name):
return param_name + self.snapshot_key
def create_snapshot_vars(self, program):
block = program.global_block()
non_dist_params = []
for param in block.iter_parameters():
if not param.is_distributed:
non_dist_params.append(param)
p2s = []
for param in non_dist_params:
snapshot = block.create_var(
name=self.snapshot_name(param.name),
shape=param.shape,
persistable=True,
stop_gradient=True,
dtype=param.dtype)
p2s.append([param, snapshot])
return p2s
def init_snapshot_vars(self, startup_program, param2snapshot):
with program_guard(startup_program):
for param, snapshot in param2snapshot:
layers.assign(param, snapshot)
def _generate_avg_loss(self, program_block, loss, avg_loss):
program_block.append_op(
type='c_allreduce_sum',
inputs={'X': [loss]},
outputs={'Out': [avg_loss]},
attrs={
'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize,
'use_calc_stream': True
})
program_block.append_op(
type='c_sync_calc_stream',
inputs={'X': [avg_loss]},
outputs={'Out': [avg_loss]},
attrs={OP_ROLE_KEY: OpRole.Optimize})
program_block.append_op(
type='scale',
inputs={'X': [avg_loss]},
outputs={'Out': [avg_loss]},
attrs={
'scale': 1.0 / self.role_maker.worker_num(),
OP_ROLE_KEY: OpRole.Optimize
})
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
minimized = self.inner_opt.minimize(
loss, startup_program=startup_program)
init_k_steps = self.user_defined_strategy.adaptive_localsgd_configs[
'init_k_steps']
begin_step_value = self.user_defined_strategy.adaptive_localsgd_configs[
'begin_step']
if startup_program is None:
startup_program = default_startup_program()
main_block = loss.block
self.nrings = 2
collective_helper = CollectiveHelper(self.role_maker, self.nrings)
collective_helper.update_startup_program(startup_program)
p2s = self.create_snapshot_vars(startup_program)
self.init_snapshot_vars(startup_program, p2s)
p2s = self.create_snapshot_vars(main_block.program)
with program_guard(main_block.program, startup_program):
step = layers.autoincreased_step_counter(begin=1)
k_steps = layers.create_global_var(
name="k_steps",
shape=[1],
value=int(init_k_steps),
dtype='int64',
persistable=True)
begin_step = layers.create_global_var(
name="begin_step",
shape=[1],
value=int(begin_step_value),
dtype='int64',
persistable=True)
last_step = layers.create_global_var(
name="last_step",
shape=[1],
value=int(0),
dtype='int64',
persistable=True)
avg_loss = layers.create_global_var(
name="avg_loss",
shape=[1],
value=float(0),
dtype=loss.dtype,
persistable=True)
lr_0 = layers.create_global_var(
name="lr_0",
shape=[1],
value=float(0),
dtype='float32',
persistable=True)
loss_0 = layers.create_global_var(
name="loss_0",
shape=[1],
value=float(0),
dtype='float32',
persistable=True)
global_lr = self.inner_opt._global_learning_rate()
def initialize():
self._generate_avg_loss(main_block, loss, avg_loss)
layers.assign(avg_loss, loss_0)
layers.assign(global_lr, lr_0)
layers.cond(step == 1, initialize)
def communicate():
sub_block = default_main_program().current_block()
ring_id = -1
for param, snapshot in p2s:
sub_block.append_op(
type='elementwise_sub',
inputs={'X': [snapshot],
'Y': [param]},
outputs={'Out': [param]},
attrs={OP_ROLE_KEY: OpRole.Optimize})
sub_block.append_op(
type='c_sync_calc_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={OP_ROLE_KEY: OpRole.Optimize})
ring_id = (ring_id + 1) % self.nrings
sub_block.append_op(
type='c_allreduce_sum',
inputs={'X': [param]},
outputs={'Out': [param]},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize
})
for ring_id in range(self.nrings):
sub_block.append_op(
type='c_sync_comm_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize
})
for param, snapshot in p2s:
sub_block.append_op(
type='scale',
inputs={'X': [param]},
outputs={'Out': [param]},
attrs={
'scale': 1.0 / self.role_maker.worker_num(),
OP_ROLE_KEY: OpRole.Optimize
})
sub_block.append_op(
type='elementwise_sub',
inputs={'X': [snapshot],
'Y': [param]},
outputs={'Out': [param]},
attrs={OP_ROLE_KEY: OpRole.Optimize})
sub_block.append_op(
type='assign',
inputs={'X': [param]},
outputs={'Out': [snapshot]},
attrs={OP_ROLE_KEY: OpRole.Optimize})
layers.assign(step, last_step)
def communicate_avg_loss():
communicate()
self._generate_avg_loss(main_block, loss, avg_loss)
next_local_steps = layers.cast(
layers.ceil(
layers.sqrt(lr_0 * avg_loss / (global_lr * loss_0) *
float(init_k_steps))),
dtype='int64')
max_local_steps = layers.fill_constant(
shape=[1], dtype='int64', value=16)
min_local_steps = layers.fill_constant(
shape=[1], dtype='int64', value=1)
next_local_steps = layers.elementwise_min(next_local_steps,
max_local_steps)
next_local_steps = layers.elementwise_max(next_local_steps,
min_local_steps)
layers.assign(next_local_steps, k_steps)
def begin_localsgd():
layers.cond(step - last_step == k_steps, communicate_avg_loss)
layers.cond(step > begin_step, begin_localsgd, communicate)
return minimized
...@@ -11,8 +11,3 @@ ...@@ -11,8 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .fs import *
from .http_server import KVHandler, KVHTTPServer, KVServer
#__all__ = ['KVHandler', 'KVHTTPServer', 'KVServer'] + fs.__all__
...@@ -32,10 +32,7 @@ import functools ...@@ -32,10 +32,7 @@ import functools
from pathlib import PurePosixPath, Path from pathlib import PurePosixPath, Path
import shutil import shutil
__all__ = [ __all__ = ['LocalFS', 'HDFSClient']
'FS', 'LocalFS', 'HDFSClient', 'ExecuteError', 'FSTimeOut',
'FSFileExistsError', 'FSFileNotExistsError', 'FSShellCmdAborted'
]
class ExecuteError(Exception): class ExecuteError(Exception):
...@@ -117,7 +114,37 @@ class FS(object): ...@@ -117,7 +114,37 @@ class FS(object):
class LocalFS(FS): class LocalFS(FS):
"""
A tool of local file system.
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
subdirs, files = client.ls_dir("./")
"""
def ls_dir(self, fs_path): def ls_dir(self, fs_path):
"""
List directorys and files under `fs_path` .
Args:
fs_path(str): The local file path.
Returns:
Tuple: Return a 2-tuple, the first is a list of all its subdirectories,
and the second is a list of all its subfiles, e.g. ([subdirname1, subdirname1, ...], [filename1, filename2, ...]).
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
subdirs, files = client.ls_dir("./")
"""
if not self.is_exist(fs_path): if not self.is_exist(fs_path):
return [], [] return [], []
...@@ -132,11 +159,46 @@ class LocalFS(FS): ...@@ -132,11 +159,46 @@ class LocalFS(FS):
return dirs, files return dirs, files
def mkdirs(self, fs_path): def mkdirs(self, fs_path):
"""
Create a remote HDFS directory.
Args:
fs_path(str): The local directory path.
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
client.mkdirs("test_mkdirs")
client.delete("test_mkdirs")
"""
assert not os.path.isfile(fs_path), "{} is already a file".format( assert not os.path.isfile(fs_path), "{} is already a file".format(
fs_path) fs_path)
os.system("mkdir -p {}".format(fs_path)) os.system("mkdir -p {}".format(fs_path))
def rename(self, fs_src_path, fs_dst_path): def rename(self, fs_src_path, fs_dst_path):
"""
Rename the file.
Args:
fs_src_path(str): The actual name of the file or directory
fs_dst_path(str): The new name of the file or directory.
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
client.touch("test_rename_src")
print(client.is_exists("test_rename_src")) # True
client.rename("test_rename_src", "test_rename_dst")
print(client.is_exists("test_rename_src")) # False
print(client.is_exists("test_rename_dst")) # True
client.delete("test_rename_dst")
"""
os.rename(fs_src_path, fs_dst_path) os.rename(fs_src_path, fs_dst_path)
def _rmr(self, fs_path): def _rmr(self, fs_path):
...@@ -146,6 +208,21 @@ class LocalFS(FS): ...@@ -146,6 +208,21 @@ class LocalFS(FS):
os.remove(fs_path) os.remove(fs_path)
def delete(self, fs_path): def delete(self, fs_path):
"""
Delete the local file path, whether it's a file or directory.
Args:
fs_path(str): The local file path.
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
client.mkdirs("test_localFS_mkdirs")
client.delete("test_localFS_mkdirs")
"""
if not self.is_exist(fs_path): if not self.is_exist(fs_path):
return return
...@@ -158,15 +235,88 @@ class LocalFS(FS): ...@@ -158,15 +235,88 @@ class LocalFS(FS):
return False return False
def is_file(self, fs_path): def is_file(self, fs_path):
"""
Whether the local file path is a file.
Args:
fs_path(str): The local file path.
Returns:
Bool: Return true if the path exists and it's a file, otherwise return false.
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
client.touch("test_is_file")
print(client.is_file("test_is_file")) # True
client.delete("test_is_file")
"""
return os.path.isfile(fs_path) return os.path.isfile(fs_path)
def is_dir(self, fs_path): def is_dir(self, fs_path):
"""
Whether the local file path is a directory.
Args:
fs_path(str): The local file path.
Returns:
Bool: Return true if the path exists and it's a directory, otherwise return false.
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
client.mkdirs("test_is_dir")
print(client.is_dir("test_is_file")) # True
client.delete("test_is_dir")
"""
return os.path.isdir(fs_path) return os.path.isdir(fs_path)
def is_exist(self, fs_path): def is_exist(self, fs_path):
"""
Whether the local file path exists.
Args:
fs_path(str): The local file path.
Returns:
Bool: Wheter it's a file or directory, return true if the path exists,
otherwise return false.
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
ret = local_fs.is_exist("test_is_exist")
"""
return os.path.exists(fs_path) return os.path.exists(fs_path)
def touch(self, fs_path, exist_ok=True): def touch(self, fs_path, exist_ok=True):
"""
Create a local file.
Args:
fs_path(str): The local file path.
exist_ok(bool): When `fs_path` exists, if `exist_ok` is set false,
program will throw an Exception. Default is true.
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
client.touch("test_touch")
client.delete("test_touch")
"""
if self.is_exist(fs_path): if self.is_exist(fs_path):
if exist_ok: if exist_ok:
return return
...@@ -175,6 +325,26 @@ class LocalFS(FS): ...@@ -175,6 +325,26 @@ class LocalFS(FS):
return Path(fs_path).touch(exist_ok=True) return Path(fs_path).touch(exist_ok=True)
def mv(self, src_path, dst_path, overwrite=False, test_exists=False): def mv(self, src_path, dst_path, overwrite=False, test_exists=False):
"""
Move a local file or directory from `src_path` to `dst_path` .
Args:
src_path(str): Name of the file or directory, that's needed to be moved.
dst_path(str): Name of the file or directory to which to move to.
overwrite(bool): Whether to re-write `dst_path` if that exists. Default is False.
test_exists(bool): Check the existence of `src_path` and `dst_path` .
When `test_exists` is set true, if `src_path` doesn't exist or `dst_path` exists, program will throw an Excetption.
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
client.touch("test_mv_src")
client.mv("test_mv_src", "test_mv_dst")
client.delete("test_mv_dst")
"""
if not self.is_exist(src_path): if not self.is_exist(src_path):
raise FSFileNotExistsError raise FSFileNotExistsError
...@@ -188,7 +358,21 @@ class LocalFS(FS): ...@@ -188,7 +358,21 @@ class LocalFS(FS):
def list_dirs(self, fs_path): def list_dirs(self, fs_path):
""" """
list directory under fs_path, and only give the pure name, not include the fs_path Only list directorys under `fs_path` .
Args:
fs_path(str): The local file path.
Returns:
List: A list of all its subdirectories, e.g. [subdirname1, subdirname1, ...].
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
subdirs = client.list_dirs("./")
""" """
if not self.is_exist(fs_path): if not self.is_exist(fs_path):
return [] return []
...@@ -200,26 +384,6 @@ class LocalFS(FS): ...@@ -200,26 +384,6 @@ class LocalFS(FS):
return dirs return dirs
"""HDFS Utils."""
def _handle_errors(f):
def handler(*args, **kwargs):
start = time.time()
while True:
try:
return f(*args, **kwargs)
except ExecuteError as e:
o = args[0]
time_out = float(o._time_out) / 1000.0
inter = float(o._sleep_inter) / 1000.0
if time.time() - start >= time_out:
raise FSTimeOut
time.sleep(inter)
return functools.wraps(f)(handler)
def _handle_errors(max_time_out=None): def _handle_errors(max_time_out=None):
def decorator(f): def decorator(f):
@functools.wraps(f) @functools.wraps(f)
...@@ -237,7 +401,7 @@ def _handle_errors(max_time_out=None): ...@@ -237,7 +401,7 @@ def _handle_errors(max_time_out=None):
while True: while True:
try: try:
return f(*args, **kwargs) return f(*args, **kwargs)
#important: only ExecuteError need to retry # important: only ExecuteError need to retry
except ExecuteError as e: except ExecuteError as e:
if time.time() - start >= time_out: if time.time() - start >= time_out:
raise FSTimeOut("args:{} timeout:{}".format( raise FSTimeOut("args:{} timeout:{}".format(
...@@ -256,12 +420,36 @@ def _handle_errors(max_time_out=None): ...@@ -256,12 +420,36 @@ def _handle_errors(max_time_out=None):
class HDFSClient(FS): class HDFSClient(FS):
"""
A tool of HDFS.
Args:
hadoop_home(str): Hadoop home.
configs(dict): Hadoop config. It is a dictionary and needs to contain the
keys: "fs.default.name" and "hadoop.job.ugi".
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
client.ls_dir("hdfs:/test_hdfs_client")
"""
def __init__( def __init__(
self, self,
hadoop_home, hadoop_home,
configs, configs,
time_out=5 * 60 * 1000, #ms time_out=5 * 60 * 1000, # ms
sleep_inter=1000): #ms sleep_inter=1000): # ms
# Raise exception if JAVA_HOME not exists. # Raise exception if JAVA_HOME not exists.
java_home = os.environ["JAVA_HOME"] java_home = os.environ["JAVA_HOME"]
...@@ -292,6 +480,30 @@ class HDFSClient(FS): ...@@ -292,6 +480,30 @@ class HDFSClient(FS):
@_handle_errors() @_handle_errors()
def list_dirs(self, fs_path): def list_dirs(self, fs_path):
"""
Only list directorys under `fs_path` .
Args:
fs_path(str): The HDFS file path.
Returns:
List: A list of all its subdirectories, e.g. [subdirname1, subdirname1, ...].
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
subdirs = client.list_dirs("hdfs:/test_hdfs_client")
"""
if not self.is_exist(fs_path): if not self.is_exist(fs_path):
return [] return []
...@@ -301,7 +513,29 @@ class HDFSClient(FS): ...@@ -301,7 +513,29 @@ class HDFSClient(FS):
@_handle_errors() @_handle_errors()
def ls_dir(self, fs_path): def ls_dir(self, fs_path):
""" """
list directory under fs_path, and only give the pure name, not include the fs_path List directorys and files under `fs_path` .
Args:
fs_path(str): The HDFS file path.
Returns:
Tuple: Return a 2-tuple, the first element is the list of all its subdirectories,
and the second one is the list of all its subfiles, e.g. ([subdirname1, subdirname1, ...], [filename1, filename2, ...]).
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
subdirs, files = client.ls_dir("hdfs:/test_hdfs_client")
""" """
if not self.is_exist(fs_path): if not self.is_exist(fs_path):
return [], [] return [], []
...@@ -340,6 +574,30 @@ class HDFSClient(FS): ...@@ -340,6 +574,30 @@ class HDFSClient(FS):
@_handle_errors() @_handle_errors()
def is_dir(self, fs_path): def is_dir(self, fs_path):
"""
Whether the remote HDFS path is a directory.
Args:
fs_path(str): The HDFS file path.
Returns:
Bool: Return true if the path exists and it's a directory, otherwise return false.
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
ret = client.is_file("hdfs:/test_hdfs_client")
"""
if not self.is_exist(fs_path): if not self.is_exist(fs_path):
return False return False
...@@ -358,6 +616,30 @@ class HDFSClient(FS): ...@@ -358,6 +616,30 @@ class HDFSClient(FS):
return True return True
def is_file(self, fs_path): def is_file(self, fs_path):
"""
Whether the remote HDFS path is a file.
Args:
fs_path(str): The HDFS file path.
Returns:
Bool: Return true if the path exists and it's a file, otherwise return false.
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
ret = client.is_file("hdfs:/test_hdfs_client")
"""
if not self.is_exist(fs_path): if not self.is_exist(fs_path):
return False return False
...@@ -365,6 +647,31 @@ class HDFSClient(FS): ...@@ -365,6 +647,31 @@ class HDFSClient(FS):
@_handle_errors() @_handle_errors()
def is_exist(self, fs_path): def is_exist(self, fs_path):
"""
Whether the remote HDFS path exists.
Args:
fs_path(str): The hdfs file path.
Returns:
Bool: Whether it's is file or directory, return true if the path exists,
otherwise return false.
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
ret = client.is_exist("hdfs:/test_hdfs_client")
"""
cmd = "ls {} ".format(fs_path) cmd = "ls {} ".format(fs_path)
ret, out = self._run_cmd(cmd, redirect_stderr=True) ret, out = self._run_cmd(cmd, redirect_stderr=True)
if ret != 0: if ret != 0:
...@@ -377,6 +684,28 @@ class HDFSClient(FS): ...@@ -377,6 +684,28 @@ class HDFSClient(FS):
# can't retry # can't retry
def upload(self, local_path, fs_path): def upload(self, local_path, fs_path):
"""
Upload the local path to remote HDFS.
Args:
local_path(str): The local path.
fs_path(str): The HDFS path.
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
client.upload("test_hdfs_client", "hdfs:/test_hdfs_client")
"""
if self.is_exist(fs_path): if self.is_exist(fs_path):
raise FSFileExistsError("{} exists".format(fs_path)) raise FSFileExistsError("{} exists".format(fs_path))
...@@ -400,6 +729,28 @@ class HDFSClient(FS): ...@@ -400,6 +729,28 @@ class HDFSClient(FS):
# can't retry # can't retry
def download(self, fs_path, local_path): def download(self, fs_path, local_path):
"""
Download remote HDFS path to the local.
Args:
fs_path(str): The HDFS path.
local_path(str): The local path.
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
client.download("hdfs:/test_hdfs_client", "./")
"""
if self.is_exist(local_path): if self.is_exist(local_path):
raise FSFileExistsError("{} exists".format(local_path)) raise FSFileExistsError("{} exists".format(local_path))
...@@ -423,6 +774,27 @@ class HDFSClient(FS): ...@@ -423,6 +774,27 @@ class HDFSClient(FS):
@_handle_errors() @_handle_errors()
def mkdirs(self, fs_path): def mkdirs(self, fs_path):
"""
Create a remote HDFS directory.
Args:
fs_path(str): The HDFS directory path.
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
client.mkdirs("hdfs:/test_hdfs_client")
"""
if self.is_exist(fs_path): if self.is_exist(fs_path):
return return
...@@ -445,6 +817,30 @@ class HDFSClient(FS): ...@@ -445,6 +817,30 @@ class HDFSClient(FS):
raise ExecuteError(cmd) raise ExecuteError(cmd)
def mv(self, fs_src_path, fs_dst_path, overwrite=False, test_exists=True): def mv(self, fs_src_path, fs_dst_path, overwrite=False, test_exists=True):
"""
Move a remote HDFS file or directory from `fs_src_path` to `fs_dst_path` .
Args:
fs_src_path(str): Name of the file or directory, that's needed to be moved.
fs_dst_path(str): Name of the file or directory to which to move to.
overwrite(bool): Whether to re-write `fs_dst_path` if that exists. Default is False.
test_exists(bool): Check the existence of `fs_src_path` and `fs_dst_path` . When `test_exists` is set true, if `fs_src_path` doesn't exist or `fs_dst_path` exists, program will throw an Excetption.
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
client.mv("hdfs:/test_hdfs_client", "hdfs:/test_hdfs_client2")
"""
if overwrite and self.is_exist(fs_dst_path): if overwrite and self.is_exist(fs_dst_path):
self.delete(fs_dst_path) self.delete(fs_dst_path)
...@@ -487,6 +883,27 @@ class HDFSClient(FS): ...@@ -487,6 +883,27 @@ class HDFSClient(FS):
@_handle_errors() @_handle_errors()
def delete(self, fs_path): def delete(self, fs_path):
"""
Delete a remote HDFS path, whether it's a file or directory.
Args:
fs_path(str): The HDFS file path.
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
client.delete("hdfs:/test_hdfs_client")
"""
if not self.is_exist(fs_path): if not self.is_exist(fs_path):
return return
...@@ -497,6 +914,27 @@ class HDFSClient(FS): ...@@ -497,6 +914,27 @@ class HDFSClient(FS):
return self._rm(fs_path) return self._rm(fs_path)
def touch(self, fs_path, exist_ok=True): def touch(self, fs_path, exist_ok=True):
"""
Create a remote HDFS file.
Args:
fs_path(str): The HDFS file path.
Examples:
.. code-block:: text
from paddle.distributed.fleet.utils.fs import HDFSClient
hadoop_home = "/home/client/hadoop-client/hadoop/"
configs = {
"fs.default.name": "hdfs://xxx.hadoop.com:54310",
"hadoop.job.ugi": "hello,hello123"
}
client = HDFSClient(hadoop_home, configs)
client.touch("hdfs:/test_hdfs_client")
"""
if self.is_exist(fs_path): if self.is_exist(fs_path):
if exist_ok: if exist_ok:
return return
......
...@@ -39,6 +39,11 @@ try: ...@@ -39,6 +39,11 @@ try:
third_lib_path = current_path + os.sep + '..' + os.sep + 'libs' third_lib_path = current_path + os.sep + '..' + os.sep + 'libs'
os.environ['path'] = third_lib_path + ';' + os.environ['path'] os.environ['path'] = third_lib_path + ';' + os.environ['path']
sys.path.insert(0, third_lib_path) sys.path.insert(0, third_lib_path)
# Note: from python3.8, PATH will not take effect
# https://github.com/python/cpython/pull/12302
# Use add_dll_directory to specify dll resolution path
if sys.version_info[:2] >= (3, 8):
os.add_dll_directory(third_lib_path)
except ImportError as e: except ImportError as e:
from .. import compat as cpt from .. import compat as cpt
......
...@@ -23,7 +23,6 @@ from paddle.fluid import framework ...@@ -23,7 +23,6 @@ from paddle.fluid import framework
from paddle.fluid.multiprocess_utils import CleanupFuncRegistrar from paddle.fluid.multiprocess_utils import CleanupFuncRegistrar
from .tracer import Tracer from .tracer import Tracer
import logging import logging
import objgraph
from ..data_feeder import convert_dtype from ..data_feeder import convert_dtype
import warnings import warnings
...@@ -368,24 +367,6 @@ def guard(place=None): ...@@ -368,24 +367,6 @@ def guard(place=None):
yield yield
def _print_debug_msg(parameter_list, limit=5, is_test=False):
if not core._is_dygraph_debug_enabled():
logging.warn(
'Debug mode is not enabled. Please set FLAGS_dygraph_debug=1 to enable debug'
)
return
unique_name_size = len(framework.unique_name.generator.ids)
tracer_var_size = len(parameter_list)
alive_cpp_var_size = len(core.VarBase._alive_vars())
if not is_test:
logging.warn(
'unique_name num: {}, tracer vars num: {}, alive cpp vars num: {}'
.format(unique_name_size, tracer_var_size, alive_cpp_var_size))
objgraph.show_growth(limit=limit)
else:
return unique_name_size, tracer_var_size, alive_cpp_var_size
@framework.dygraph_only @framework.dygraph_only
def grad(outputs, def grad(outputs,
inputs, inputs,
......
...@@ -195,58 +195,11 @@ def load_dygraph(model_path, config=None): ...@@ -195,58 +195,11 @@ def load_dygraph(model_path, config=None):
params_file_path = model_prefix + ".pdparams" params_file_path = model_prefix + ".pdparams"
opti_file_path = model_prefix + ".pdopt" opti_file_path = model_prefix + ".pdopt"
# deal with argument `configs` # deal with argument `config`
configs = config if config is None:
if configs is None: config = SaveLoadConfig()
configs = SaveLoadConfig()
if not os.path.exists(params_file_path) and not os.path.exists(
opti_file_path):
# Load state dict by `jit.save/io.save_inference_model` save format
# NOTE(chenweihang): [ Compatibility of save_inference_model save format ]
# The model saved by `save_inference_model` does not completely correspond to
# the information required by the `state_dict` under the dygraph.
# `save_inference_model` not save structured name, we need to remind
# the user to configure the `use_structured_name` argument when `set_state_dict`
# NOTE(chenweihang): `jit.save` doesn't save optimizer state
# 1. check model path
if not os.path.isdir(model_prefix):
raise ValueError("Model saved directory '%s' is not exists." %
model_prefix)
# 2. load program desc & construct _ProgramHolder if os.path.exists(params_file_path) or os.path.exists(opti_file_path):
programs = _construct_program_holders(model_path,
configs.model_filename)
# 3. load layer parameters & buffers
# NOTE: using fluid.dygraph.guard() here will cause import error in py2
with guard():
persistable_var_dict = _construct_params_and_buffers(
model_prefix,
programs,
configs.separate_params,
configs.params_filename,
append_suffix=False)
# 4. construct state_dict
para_dict = dict()
for var_name in persistable_var_dict:
para_dict[var_name] = persistable_var_dict[var_name].numpy()
# if __variables.info__ exists, we can recover structured_name
var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME)
if os.path.exists(var_info_path):
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
structured_para_dict = dict()
for var_name in para_dict:
structured_name = extra_var_info[var_name].get(
'structured_name', None)
assert structured_name is not None, "Cannot find saved variable (%s)'s structured name in saved model." % var_name
structured_para_dict[structured_name] = para_dict[var_name]
para_dict = structured_para_dict
else:
# Load state dict by `save_dygraph` save format # Load state dict by `save_dygraph` save format
para_dict = {} para_dict = {}
if os.path.exists(params_file_path): if os.path.exists(params_file_path):
...@@ -254,12 +207,103 @@ def load_dygraph(model_path, config=None): ...@@ -254,12 +207,103 @@ def load_dygraph(model_path, config=None):
para_dict = pickle.load(f) if six.PY2 else pickle.load( para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1') f, encoding='latin1')
if not configs.keep_name_table and "StructuredToParameterName@@" in para_dict: if not config.keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"] del para_dict["StructuredToParameterName@@"]
if os.path.exists(opti_file_path): if os.path.exists(opti_file_path):
with open(opti_file_path, 'rb') as f: with open(opti_file_path, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load( opti_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1') f, encoding='latin1')
else:
# check model path
if not os.path.isdir(model_prefix):
raise ValueError("Model saved directory '%s' is not exists." %
model_prefix)
# check whether model file exists
if config.model_filename is None:
model_filename = '__model__'
else:
model_filename = config.model_filename
model_file_path = os.path.join(model_path, model_filename)
if os.path.exists(model_file_path):
# Load state dict by `jit.save/io.save_inference_model` save format
# NOTE(chenweihang): [ Compatibility of save_inference_model save format ]
# The model saved by `save_inference_model` does not completely correspond to
# the information required by the `state_dict` under the dygraph.
# `save_inference_model` not save structured name, we need to remind
# the user to configure the `use_structured_name` argument when `set_state_dict`
# NOTE(chenweihang): `jit.save` doesn't save optimizer state
# 1. load program desc & construct _ProgramHolder
programs = _construct_program_holders(model_path,
config.model_filename)
# 2. load layer parameters & buffers
# NOTE: using fluid.dygraph.guard() here will cause import error in py2
with guard():
persistable_var_dict = _construct_params_and_buffers(
model_prefix,
programs,
config.separate_params,
config.params_filename,
append_suffix=False)
# 3. construct state_dict
para_dict = dict()
for var_name in persistable_var_dict:
para_dict[var_name] = persistable_var_dict[var_name].numpy()
# if __variables.info__ exists, we can recover structured_name
var_info_path = os.path.join(model_prefix,
EXTRA_VAR_INFO_FILENAME)
if os.path.exists(var_info_path):
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
structured_para_dict = dict()
for var_name in para_dict:
structured_name = extra_var_info[var_name].get(
'structured_name', None)
assert structured_name is not None, "Cannot find saved variable (%s)'s structured name in saved model." % var_name
structured_para_dict[structured_name] = para_dict[
var_name]
para_dict = structured_para_dict
else:
# load state dict by `io.save_params/persistables` save format
# TODO(chenweihang): [ Now only supports loading parameters seperately ]
# If users save all parameters as one file, the [ variable.name -> variable ]
# mapping info will lost, so users need to give variable list, but users build
# variable list in dygraph mode is difficult, we recommend users to use
# paddle.io.load_program_state in this case
# Try to load all the files in the directory in VarBase format,
# the file name is used as the name of VarBase
load_var_list = []
# 1. load file names
var_name_list = []
for root, _, files in os.walk(model_path):
for filename in files:
file_path = os.path.join(root, filename)
tmp_var_name = os.path.relpath(file_path, model_path)
var_name = tmp_var_name.replace("\\", "/")
var_name_list.append(var_name)
# 2. create and load VarBase
with guard():
for name in var_name_list:
new_var = _varbase_creator(name=name, persistable=True)
_dygraph_tracer().trace_op(
type='load',
inputs={},
outputs={'Out': new_var},
attrs={'file_path': os.path.join(model_path, name)})
load_var_list.append(new_var)
# 3. construct state_dict
para_dict = dict()
for var in load_var_list:
para_dict[var.name] = var.numpy()
return para_dict, opti_dict return para_dict, opti_dict
...@@ -246,7 +246,7 @@ class StaticLayer(object): ...@@ -246,7 +246,7 @@ class StaticLayer(object):
self._function_spec = FunctionSpec(function, input_spec) self._function_spec = FunctionSpec(function, input_spec)
self._program_cache = ProgramCache() self._program_cache = ProgramCache()
self._descriptor_cache = weakref.WeakKeyDictionary() self._descriptor_cache = weakref.WeakKeyDictionary()
# Note: Hold a reference to ProgramTranslator for switching `enable_declarative`. # Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
self._program_trans = ProgramTranslator() self._program_trans = ProgramTranslator()
def __get__(self, instance, owner): def __get__(self, instance, owner):
...@@ -299,16 +299,17 @@ class StaticLayer(object): ...@@ -299,16 +299,17 @@ class StaticLayer(object):
""" """
# 1. call dygraph function directly if not enable `declarative` # 1. call dygraph function directly if not enable `declarative`
if not self._program_trans.enable_declarative: if not self._program_trans.enable_to_static:
logging_utils.warn( logging_utils.warn(
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. " "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output.") "We will just return dygraph output. If you would like to get static graph output, please call API "
"ProgramTranslator.enable(True)")
return self._call_dygraph_function(*args, **kwargs) return self._call_dygraph_function(*args, **kwargs)
if not in_dygraph_mode() and self._program_trans.enable_declarative: if not in_dygraph_mode():
raise RuntimeError( raise RuntimeError(
"Failed to run the callable object {} decorated by '@paddle.jit.to_static', " "Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
"because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the " "because it is NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
"following API: paddle.disable_static().".format( "following API: paddle.disable_static().".format(
self.dygraph_function)) self.dygraph_function))
...@@ -723,15 +724,15 @@ class ProgramTranslator(object): ...@@ -723,15 +724,15 @@ class ProgramTranslator(object):
return return
self._initialized = True self._initialized = True
self._program_cache = ProgramCache() self._program_cache = ProgramCache()
self.enable_declarative = True self.enable_to_static = True
def enable(self, enable_declarative): def enable(self, enable_to_static):
""" """
Enable or disable the converting from imperative to declarative by Enable or disable the converting from imperative to declarative by
ProgramTranslator globally. ProgramTranslator globally.
Args: Args:
enable_declarative (bool): True or False to enable or disable declarative. enable_to_static (bool): True or False to enable or disable declarative.
Returns: Returns:
None. None.
...@@ -760,9 +761,9 @@ class ProgramTranslator(object): ...@@ -760,9 +761,9 @@ class ProgramTranslator(object):
print(func(x).numpy()) # [[2. 2.]] print(func(x).numpy()) # [[2. 2.]]
""" """
check_type(enable_declarative, "enable_declarative", bool, check_type(enable_to_static, "enable_to_static", bool,
"ProgramTranslator.enable") "ProgramTranslator.enable")
self.enable_declarative = enable_declarative self.enable_to_static = enable_to_static
def get_output(self, dygraph_func, *args, **kwargs): def get_output(self, dygraph_func, *args, **kwargs):
""" """
...@@ -803,10 +804,12 @@ class ProgramTranslator(object): ...@@ -803,10 +804,12 @@ class ProgramTranslator(object):
assert callable( assert callable(
dygraph_func dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_output" ), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
if not self.enable_declarative: if not self.enable_to_static:
warnings.warn( warnings.warn(
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. " "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output.") "We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
)
return dygraph_func(*args, **kwargs) return dygraph_func(*args, **kwargs)
try: try:
function_spec = FunctionSpec(dygraph_func) function_spec = FunctionSpec(dygraph_func)
...@@ -876,10 +879,11 @@ class ProgramTranslator(object): ...@@ -876,10 +879,11 @@ class ProgramTranslator(object):
assert callable( assert callable(
dygraph_func dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_func" ), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
if not self.enable_declarative: if not self.enable_to_static:
warnings.warn( warnings.warn(
"The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable=False. We will " "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will "
"just return dygraph output.") "just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output."
)
return dygraph_func return dygraph_func
static_func = convert_to_static(dygraph_func) static_func = convert_to_static(dygraph_func)
...@@ -929,10 +933,12 @@ class ProgramTranslator(object): ...@@ -929,10 +933,12 @@ class ProgramTranslator(object):
assert callable( assert callable(
dygraph_func dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_program" ), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
if not self.enable_declarative: if not self.enable_to_static:
warnings.warn( warnings.warn(
"The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable=False." "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False."
"We will just return dygraph output.") "We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
)
return dygraph_func(*args, **kwargs) return dygraph_func(*args, **kwargs)
function_spec = FunctionSpec(dygraph_func) function_spec = FunctionSpec(dygraph_func)
......
...@@ -119,7 +119,7 @@ def _dygraph_to_static_func_(dygraph_func): ...@@ -119,7 +119,7 @@ def _dygraph_to_static_func_(dygraph_func):
# TODO: remove this decorator after we finalize training API # TODO: remove this decorator after we finalize training API
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
if in_dygraph_mode() or not program_translator.enable_declarative: if in_dygraph_mode() or not program_translator.enable_to_static:
warnings.warn( warnings.warn(
"The decorator 'dygraph_to_static_func' doesn't work in " "The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode or set ProgramTranslator.enable to False. " "dygraph mode or set ProgramTranslator.enable to False. "
...@@ -832,9 +832,9 @@ def save(layer, model_path, input_spec=None, config=None): ...@@ -832,9 +832,9 @@ def save(layer, model_path, input_spec=None, config=None):
# 1. input check # 1. input check
prog_translator = ProgramTranslator() prog_translator = ProgramTranslator()
if not prog_translator.enable: if not prog_translator.enable_to_static:
raise RuntimeError( raise RuntimeError(
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable=False." "The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
) )
if not isinstance(layer, Layer): if not isinstance(layer, Layer):
raise TypeError( raise TypeError(
......
...@@ -98,7 +98,7 @@ class AutoCheckpointChecker(object): ...@@ -98,7 +98,7 @@ class AutoCheckpointChecker(object):
self._fs_cache = os.getenv("PADDLE_EDL_FS_CACHE", ".cache") self._fs_cache = os.getenv("PADDLE_EDL_FS_CACHE", ".cache")
self._save_checkpoint_inter = int( self._save_checkpoint_inter = int(
os.getenv("PADDLE_EDL_SAVE_CHECKPOINT_INTER", "900")) #s os.getenv("PADDLE_EDL_SAVE_CHECKPOINT_INTER", "900")) # s
if not self._ce_test: if not self._ce_test:
assert len(self._hdfs_home) > 3 and \ assert len(self._hdfs_home) > 3 and \
...@@ -132,7 +132,7 @@ class AutoCheckpointChecker(object): ...@@ -132,7 +132,7 @@ class AutoCheckpointChecker(object):
if in_dygraph_mode(): if in_dygraph_mode():
return False return False
return self._run_env is not None and \ return self._run_env is not None and \
self._platform is not None and \ self._platform is not None and \
self._job_id is not None and \ self._job_id is not None and \
self._hdfs_home is not None and \ self._hdfs_home is not None and \
......
...@@ -26,8 +26,7 @@ import paddle.fluid as fluid ...@@ -26,8 +26,7 @@ import paddle.fluid as fluid
from paddle.fluid.log_helper import get_logger from paddle.fluid.log_helper import get_logger
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_pslib from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_pslib
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_transpiler from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_transpiler
from . import hdfs from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient
from .hdfs import *
from . import utils from . import utils
__all__ = ["FleetUtil"] __all__ = ["FleetUtil"]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import subprocess
import multiprocessing
from datetime import datetime
import re
import copy
import errno
import time
import logging
import abc
from pathlib import PurePosixPath, Path
import shutil
__all__ = ['FS', 'LocalFS']
class ExecuteError(Exception):
pass
class FSFileExistsError(Exception):
pass
class FSFileNotExistsError(Exception):
pass
class FSTimeOut(Exception):
pass
class FSShellCmdAborted(ExecuteError):
pass
class FS(object):
@abc.abstractmethod
def ls_dir(self, fs_path):
raise NotImplementedError
@abc.abstractmethod
def is_file(self, fs_path):
raise NotImplementedError
@abc.abstractmethod
def is_dir(self, fs_path):
raise NotImplementedError
@abc.abstractmethod
def is_exist(self, fs_path):
raise NotImplementedError
@abc.abstractmethod
def upload(self, local_path, fs_path):
raise NotImplementedError
@abc.abstractmethod
def download(self, fs_path, local_path):
raise NotImplementedError
@abc.abstractmethod
def mkdirs(self, fs_path):
raise NotImplementedError
@abc.abstractmethod
def delete(self, fs_path):
raise NotImplementedError
@abc.abstractmethod
def need_upload_download(self):
raise NotImplementedError
@abc.abstractmethod
def rename(self, fs_src_path, fs_dst_path):
raise NotImplementedError
@abc.abstractmethod
def mv(self, fs_src_path, fs_dst_path, overwrite=False, test_exists=False):
raise NotImplementedError
@abc.abstractmethod
def upload_dir(self, local_dir, dest_dir):
raise NotImplementedError
@abc.abstractmethod
def list_dirs(self, fs_path):
raise NotImplementedError
@abc.abstractmethod
def touch(self, fs_path, exist_ok=True):
raise NotImplementedError
class LocalFS(FS):
def ls_dir(self, fs_path):
return [f for f in os.listdir(fs_path)]
def mkdirs(self, fs_path):
assert not os.path.isfile(fs_path), "{} is already a file".format(
fs_path)
os.system("mkdir -p {}".format(fs_path))
def rename(self, fs_src_path, fs_dst_path):
os.rename(fs_src_path, fs_dst_path)
def _rmr(self, fs_path):
shutil.rmtree(fs_path)
def _rm(self, fs_path):
os.remove(fs_path)
def delete(self, fs_path):
if not self.is_exist(fs_path):
return
if os.path.isfile(fs_path):
return self._rm(fs_path)
return self._rmr(fs_path)
def need_upload_download(self):
return False
def is_file(self, fs_path):
return os.path.isfile(fs_path)
def is_dir(self, fs_path):
return os.path.isdir(fs_path)
def is_exist(self, fs_path):
return os.path.exists(fs_path)
def touch(self, fs_path, exist_ok=True):
if self.is_exist(fs_path):
if exist_ok:
return
raise FSFileExistsError
return Path(fs_path).touch(exist_ok=True)
def mv(self, src_path, dst_path, overwrite=False, test_exists=False):
if not self.is_exist(src_path):
raise FSFileNotExistsError
if overwrite and self.is_exist(dst_path):
self.delete(dst_path)
if self.is_exist(dst_path):
raise FSFileExistsError
return self.rename(src_path, dst_path)
def list_dirs(self, fs_path):
"""
list directory under fs_path, and only give the pure name, not include the fs_path
"""
if not self.is_exist(fs_path):
return []
dirs = [
f for f in os.listdir(fs_path) if os.path.isdir(fs_path + "/" + f)
]
return dirs
...@@ -3570,8 +3570,10 @@ class ExponentialMovingAverage(object): ...@@ -3570,8 +3570,10 @@ class ExponentialMovingAverage(object):
# bias correction # bias correction
with layers.control_flow.Switch() as switch: with layers.control_flow.Switch() as switch:
with switch.case(global_step > 0): with switch.case(global_step > 0):
layers.assign(output=ema, input=ema / (1.0 - decay_pow)) layers.assign(
layers.assign(input=ema, output=param) output=param, input=ema / (1.0 - decay_pow))
with switch.default():
layers.assign(output=param, input=ema)
self.restore_program = Program() self.restore_program = Program()
block = self.restore_program.global_block() block = self.restore_program.global_block()
......
...@@ -4,6 +4,7 @@ set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0 FLAGS_fast_eager_deletion_mode=1 FL ...@@ -4,6 +4,7 @@ set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0 FLAGS_fast_eager_deletion_mode=1 FL
set(dist_ENVS http_proxy="" https_proxy="") set(dist_ENVS http_proxy="" https_proxy="")
file(GLOB DIST_TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_dist_*.py") file(GLOB DIST_TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_dist_*.py")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_op")
if(NOT WITH_NCCL) if(NOT WITH_NCCL)
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_mnist_dgc_nccl") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_mnist_dgc_nccl")
endif() endif()
...@@ -102,7 +103,6 @@ if(WIN32) ...@@ -102,7 +103,6 @@ if(WIN32)
endif() endif()
LIST(REMOVE_ITEM TEST_OPS test_fleet_rolemaker_new)
LIST(REMOVE_ITEM TEST_OPS test_auto_checkpoint) LIST(REMOVE_ITEM TEST_OPS test_auto_checkpoint)
LIST(REMOVE_ITEM TEST_OPS test_auto_checkpoint1) LIST(REMOVE_ITEM TEST_OPS test_auto_checkpoint1)
LIST(REMOVE_ITEM TEST_OPS test_auto_checkpoint2) LIST(REMOVE_ITEM TEST_OPS test_auto_checkpoint2)
...@@ -326,7 +326,6 @@ list(REMOVE_ITEM TEST_OPS test_basic_gru_api) ...@@ -326,7 +326,6 @@ list(REMOVE_ITEM TEST_OPS test_basic_gru_api)
list(REMOVE_ITEM TEST_OPS test_basic_gru_unit_op) list(REMOVE_ITEM TEST_OPS test_basic_gru_unit_op)
list(REMOVE_ITEM TEST_OPS test_basic_lstm_api) list(REMOVE_ITEM TEST_OPS test_basic_lstm_api)
list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op) list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op)
list(REMOVE_ITEM TEST_OPS test_imperative_debug_string)
list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass) list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass)
list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_mnist) list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_mnist)
list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_while) list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_while)
...@@ -416,7 +415,6 @@ py_test_modules(test_imperative_ocr_attention_model MODULES test_imperative_ocr_ ...@@ -416,7 +415,6 @@ py_test_modules(test_imperative_ocr_attention_model MODULES test_imperative_ocr_
py_test_modules(test_install_check MODULES test_install_check ENVS py_test_modules(test_install_check MODULES test_install_check ENVS
FLAGS_cudnn_deterministic=1 SERIAL) FLAGS_cudnn_deterministic=1 SERIAL)
set_tests_properties(test_install_check PROPERTIES LABELS "RUN_TYPE=DIST") set_tests_properties(test_install_check PROPERTIES LABELS "RUN_TYPE=DIST")
py_test_modules(test_imperative_debug_string MODULES test_imperative_debug_string ENVS FLAGS_dygraph_debug=1)
py_test_modules(test_imperative_static_runner_mnist MODULES test_imperative_static_runner_mnist ENVS py_test_modules(test_imperative_static_runner_mnist MODULES test_imperative_static_runner_mnist ENVS
FLAGS_cudnn_deterministic=1) FLAGS_cudnn_deterministic=1)
py_test_modules(test_imperative_static_runner_while MODULES test_imperative_static_runner_while ENVS py_test_modules(test_imperative_static_runner_while MODULES test_imperative_static_runner_while ENVS
...@@ -465,8 +463,8 @@ if(WITH_DISTRIBUTE) ...@@ -465,8 +463,8 @@ if(WITH_DISTRIBUTE)
#py_test_modules(test_fleet_auto MODULES test_fleet_auto ENVS ${dist_ENVS}) #py_test_modules(test_fleet_auto MODULES test_fleet_auto ENVS ${dist_ENVS})
if(NOT WIN32) if(NOT WIN32)
py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS})
#py_test_modules(test_fleet_lars_meta_optimizer MODULES test_fleet_lars_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_lars_meta_optimizer MODULES test_fleet_lars_meta_optimizer ENVS ${dist_ENVS})
#py_test_modules(test_fleet_lamb_meta_optimizer MODULES test_fleet_lamb_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_lamb_meta_optimizer MODULES test_fleet_lamb_meta_optimizer ENVS ${dist_ENVS})
endif(NOT WIN32) endif(NOT WIN32)
endif(NOT APPLE) endif(NOT APPLE)
if(WITH_DGC) if(WITH_DGC)
...@@ -560,7 +558,7 @@ endif() ...@@ -560,7 +558,7 @@ endif()
set_tests_properties(test_parallel_executor_test_while_train test_parallel_executor_mnist set_tests_properties(test_parallel_executor_test_while_train test_parallel_executor_mnist
test_parallel_executor_feed_persistable_var test_parallel_executor_feed_persistable_var
test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
test_data_norm_op test_imperative_using_non_zero_gpu test_data_norm_op
test_dataloader_keep_order test_dataloader_keep_order
test_dataloader_unkeep_order test_dataloader_unkeep_order
test_parallel_executor_fetch_isolated_var test_parallel_executor_fetch_isolated_var
......
...@@ -20,8 +20,7 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet ...@@ -20,8 +20,7 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os import os
import sys import sys
from paddle.fluid.incubate.fleet.utils.fs import LocalFS from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
import paddle.fluid.incubate.checkpoint.auto_checkpoint as acp import paddle.fluid.incubate.checkpoint.auto_checkpoint as acp
from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel
from paddle.fluid.framework import program_guard from paddle.fluid.framework import program_guard
......
...@@ -19,7 +19,7 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet ...@@ -19,7 +19,7 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os import os
import sys import sys
from paddle.distributed.fleet.utils import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
java_home = os.environ["JAVA_HOME"] java_home = os.environ["JAVA_HOME"]
......
...@@ -67,13 +67,13 @@ class AutoCheckpointTestDist(AutoCheckPointACLBase): ...@@ -67,13 +67,13 @@ class AutoCheckpointTestDist(AutoCheckPointACLBase):
save_dir = "./run_save_0" save_dir = "./run_save_0"
fs.delete(save_dir) fs.delete(save_dir)
#basic # basic
exe, main_prog, startup_prog = self._generate() exe, main_prog, startup_prog = self._generate()
compiled, data_loader, optimizer, loss, image, label = \ compiled, data_loader, optimizer, loss, image, label = \
self._init_env(exe, main_prog, startup_prog, minimize=False) self._init_env(exe, main_prog, startup_prog, minimize=False)
#fleet # fleet
os.environ["TRAINING_ROLE"] = "TRAINER" os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["PADDLE_TRAINER_ID"] = "0" os.environ["PADDLE_TRAINER_ID"] = "0"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:6070" os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:6070"
......
...@@ -21,8 +21,7 @@ from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver ...@@ -21,8 +21,7 @@ from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver
import os import os
import sys import sys
from paddle.fluid.incubate.fleet.utils.fs import LocalFS from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver
......
...@@ -33,6 +33,14 @@ def execute(main_program, startup_program): ...@@ -33,6 +33,14 @@ def execute(main_program, startup_program):
exe.run(main_program) exe.run(main_program)
def get_vaild_warning_num(warning, w):
num = 0
for i in range(len(w)):
if warning in str(w[i].message):
num += 1
return num
class TestDeviceGuard(unittest.TestCase): class TestDeviceGuard(unittest.TestCase):
def test_device_guard(self): def test_device_guard(self):
main_program = fluid.Program() main_program = fluid.Program()
...@@ -133,7 +141,10 @@ class TestDeviceGuard(unittest.TestCase): ...@@ -133,7 +141,10 @@ class TestDeviceGuard(unittest.TestCase):
i = fluid.layers.increment(x=i, value=1, in_place=True) i = fluid.layers.increment(x=i, value=1, in_place=True)
fluid.layers.less_than(x=i, y=loop_len, cond=cond) fluid.layers.less_than(x=i, y=loop_len, cond=cond)
assert len(w) == 1 warning = "The Op(while) is not support to set device."
warning_num = get_vaild_warning_num(warning, w)
assert warning_num == 1
all_ops = main_program.global_block().ops all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops: for op in all_ops:
...@@ -169,7 +180,10 @@ class TestDeviceGuard(unittest.TestCase): ...@@ -169,7 +180,10 @@ class TestDeviceGuard(unittest.TestCase):
shape=[1], value=4.0, dtype='float32') shape=[1], value=4.0, dtype='float32')
result = fluid.layers.less_than(x=x, y=y, force_cpu=False) result = fluid.layers.less_than(x=x, y=y, force_cpu=False)
assert len(w) == 2 warning = "\'device_guard\' has higher priority when they are used at the same time."
warning_num = get_vaild_warning_num(warning, w)
assert warning_num == 2
all_ops = main_program.global_block().ops all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops: for op in all_ops:
......
...@@ -67,6 +67,13 @@ class TestElementwiseModOp_scalar(TestElementwiseModOp): ...@@ -67,6 +67,13 @@ class TestElementwiseModOp_scalar(TestElementwiseModOp):
self.out = np.floor_divide(self.x, self.y) self.out = np.floor_divide(self.x, self.y)
class TestElementwiseModOpInverse(TestElementwiseModOp):
def init_input_output(self):
self.x = np.random.uniform(0, 10000, [10]).astype(self.dtype)
self.y = np.random.uniform(0, 1000, [10, 10]).astype(self.dtype)
self.out = np.floor_divide(self.x, self.y)
class TestFloorDivideOp(unittest.TestCase): class TestFloorDivideOp(unittest.TestCase):
def test_name(self): def test_name(self):
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
......
...@@ -21,8 +21,7 @@ from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver ...@@ -21,8 +21,7 @@ from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver
import os import os
import sys import sys
from paddle.fluid.incubate.fleet.utils.fs import LocalFS from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver
......
...@@ -86,6 +86,13 @@ class TestStrategyConfig(unittest.TestCase): ...@@ -86,6 +86,13 @@ class TestStrategyConfig(unittest.TestCase):
self.assertEqual(strategy.localsgd_configs["k_steps"], 4) self.assertEqual(strategy.localsgd_configs["k_steps"], 4)
self.assertEqual(strategy.localsgd_configs["begin_step"], 120) self.assertEqual(strategy.localsgd_configs["begin_step"], 120)
def test_adaptive_localsgd_configs(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {"init_k_steps": 1, "begin_step": 120}
strategy.adaptive_localsgd_configs = configs
self.assertEqual(strategy.adaptive_localsgd_configs["init_k_steps"], 1)
self.assertEqual(strategy.adaptive_localsgd_configs["begin_step"], 120)
def test_dgc(self): def test_dgc(self):
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.dgc = True strategy.dgc = True
......
...@@ -52,5 +52,36 @@ class TestFleetLocalSGDMetaOptimizer(unittest.TestCase): ...@@ -52,5 +52,36 @@ class TestFleetLocalSGDMetaOptimizer(unittest.TestCase):
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
class TestFleetAdaptiveLocalSGDMetaOptimizer(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ID"] = "1"
os.environ[
"PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"
def test_adaptive_localsgd_optimizer(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.adaptive_localsgd = True
config = strategy.adaptive_localsgd_configs
config['init_k_steps'] = 1
config['begin_step'] = 1
strategy.adaptive_localsgd_configs = config
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -40,9 +40,9 @@ class TestCloudRoleMaker(unittest.TestCase): ...@@ -40,9 +40,9 @@ class TestCloudRoleMaker(unittest.TestCase):
from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib
from paddle.fluid.incubate.fleet.base.role_maker import \ from paddle.fluid.incubate.fleet.base.role_maker import \
GeneralRoleMaker GeneralRoleMaker
from paddle.distributed.fleet.utils import KVHandler from paddle.distributed.fleet.utils.http_server import KVHandler
from paddle.distributed.fleet.utils import KVServer from paddle.distributed.fleet.utils.http_server import KVServer
from paddle.distributed.fleet.utils import KVHTTPServer from paddle.distributed.fleet.utils.http_server import KVHTTPServer
except: except:
print("warning: no fleet, skip test_pslib_4") print("warning: no fleet, skip test_pslib_4")
return return
......
...@@ -81,12 +81,12 @@ class TestFleetUtil(unittest.TestCase): ...@@ -81,12 +81,12 @@ class TestFleetUtil(unittest.TestCase):
self.assertEqual(user_id, 10) self.assertEqual(user_id, 10)
def test_fs(self): def test_fs(self):
from paddle.distributed.fleet.utils import LocalFS from paddle.distributed.fleet.utils.fs import LocalFS
fs = LocalFS() fs = LocalFS()
dirs, files = fs.ls_dir("test_tmp") dirs, files = fs.ls_dir("test_tmp")
dirs, files = fs.ls_dir("./") dirs, files = fs.ls_dir("./")
self.assertFalse(fs.need_upload_download()) self.assertFalse(fs.need_upload_download())
fleet_util.set_file_system(fs) fleet_util._set_file_system(fs)
def test_barrier(self): def test_barrier(self):
try: try:
......
...@@ -20,7 +20,7 @@ import os ...@@ -20,7 +20,7 @@ import os
import sys import sys
import inspect import inspect
from paddle.distributed.fleet.utils import LocalFS, FS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError from paddle.distributed.fleet.utils.fs import LocalFS, FS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
class FSTest(unittest.TestCase): class FSTest(unittest.TestCase):
......
...@@ -216,7 +216,7 @@ class API_TestGather(unittest.TestCase): ...@@ -216,7 +216,7 @@ class API_TestGather(unittest.TestCase):
"index": index_np, "index": index_np,
'axis': axis_np}, 'axis': axis_np},
fetch_list=[out]) fetch_list=[out])
expected_output = gather_numpy(x_np, index_np, axis_np) expected_output = gather_numpy(x_np, index_np, axis_np[0])
self.assertTrue(np.allclose(result, expected_output)) self.assertTrue(np.allclose(result, expected_output))
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker import paddle.fluid.incubate.fleet.base.role_maker as role_maker
...@@ -19,12 +20,10 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet ...@@ -19,12 +20,10 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os import os
import sys import sys
from paddle.distributed.fleet.utils import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
java_home = os.environ["JAVA_HOME"] java_home = os.environ["JAVA_HOME"]
from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase
class FSTest1(FSTestBase): class FSTest1(FSTestBase):
def test_timeout(self): def test_timeout(self):
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker import paddle.fluid.incubate.fleet.base.role_maker as role_maker
...@@ -19,12 +20,10 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet ...@@ -19,12 +20,10 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os import os
import sys import sys
from paddle.distributed.fleet.utils import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
java_home = os.environ["JAVA_HOME"] java_home = os.environ["JAVA_HOME"]
from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase
class FSTest2(FSTestBase): class FSTest2(FSTestBase):
def test_hdfs(self): def test_hdfs(self):
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker import paddle.fluid.incubate.fleet.base.role_maker as role_maker
...@@ -19,12 +20,10 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet ...@@ -19,12 +20,10 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os import os
import sys import sys
from paddle.distributed.fleet.utils import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
java_home = os.environ["JAVA_HOME"] java_home = os.environ["JAVA_HOME"]
from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase
class FSTest3(FSTestBase): class FSTest3(FSTestBase):
def test_hdfs(self): def test_hdfs(self):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import numpy as np
class MLP(fluid.Layer):
def __init__(self, input_size):
super(MLP, self).__init__()
self._linear1 = fluid.dygraph.Linear(
input_size,
3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1)))
self._linear2 = fluid.dygraph.Linear(
3,
4,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1)))
def forward(self, inputs):
x = self._linear1(inputs)
x = self._linear2(x)
x = fluid.layers.reduce_sum(x)
return x
class TestDygraphDebugString(unittest.TestCase):
def test_dygraph_debug_string(self):
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
unique_name = 0
trace_var = 0
alive_var = 0
with fluid.dygraph.guard():
mlp = MLP(input_size=2)
for i in range(10):
var_inp = fluid.dygraph.base.to_variable(np_inp)
out = mlp(var_inp)
out.backward()
mlp.clear_gradients()
unique_name_tmp, trace_var_tmp, alive_var_tmp = fluid.dygraph.base._print_debug_msg(
mlp.parameters(), is_test=True)
if i > 0:
self.assertGreaterEqual(unique_name, unique_name_tmp)
self.assertGreaterEqual(trace_var, trace_var_tmp)
self.assertGreaterEqual(alive_var, alive_var_tmp)
else:
unique_name = unique_name_tmp
trace_var = trace_var_tmp
alive_var = alive_var_tmp
try:
fluid.dygraph.base._print_debug_msg(mlp.parameters())
except Exception as e:
raise RuntimeError(
"No Exception is accepted in _print_debug_msg, but we got: {}".
format(e))
...@@ -21,7 +21,6 @@ import numpy as np ...@@ -21,7 +21,6 @@ import numpy as np
class TestImperativeUsingNonZeroGpu(unittest.TestCase): class TestImperativeUsingNonZeroGpu(unittest.TestCase):
def run_main(self, np_arr, place): def run_main(self, np_arr, place):
with guard(place): with guard(place):
embedding = Embedding(size=[10, 10])
var = to_variable(np_arr) var = to_variable(np_arr)
self.assertTrue(np.array_equal(np_arr, var.numpy())) self.assertTrue(np.array_equal(np_arr, var.numpy()))
...@@ -30,7 +29,6 @@ class TestImperativeUsingNonZeroGpu(unittest.TestCase): ...@@ -30,7 +29,6 @@ class TestImperativeUsingNonZeroGpu(unittest.TestCase):
return return
np_arr = np.random.random([11, 13]).astype('float32') np_arr = np.random.random([11, 13]).astype('float32')
self.run_main(np_arr, fluid.CUDAPlace(1))
self.run_main(np_arr, fluid.CUDAPlace(0)) self.run_main(np_arr, fluid.CUDAPlace(0))
......
...@@ -64,7 +64,7 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase): ...@@ -64,7 +64,7 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
self.batch_size = 128 self.batch_size = 128
self.batch_num = 10 self.batch_num = 10
def train_and_save_model(self): def train_and_save_model(self, only_params=False):
with new_program_scope(): with new_program_scope():
startup_program = fluid.default_startup_program() startup_program = fluid.default_startup_program()
main_program = fluid.default_main_program() main_program = fluid.default_main_program()
...@@ -102,11 +102,15 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase): ...@@ -102,11 +102,15 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
static_param_dict[param.name] = fluid.executor._fetch_var( static_param_dict[param.name] = fluid.executor._fetch_var(
param.name) param.name)
fluid.io.save_inference_model( if only_params:
self.save_dirname, ["img"], [prediction], fluid.io.save_params(
exe, exe, self.save_dirname, filename=self.params_filename)
model_filename=self.model_filename, else:
params_filename=self.params_filename) fluid.io.save_inference_model(
self.save_dirname, ["img"], [prediction],
exe,
model_filename=self.model_filename,
params_filename=self.params_filename)
return static_param_dict return static_param_dict
...@@ -120,9 +124,7 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase): ...@@ -120,9 +124,7 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
self.params_filename = None self.params_filename = None
orig_param_dict = self.train_and_save_model() orig_param_dict = self.train_and_save_model()
configs = paddle.SaveLoadConfig() load_param_dict, _ = paddle.load(self.save_dirname)
configs.separate_params = True
load_param_dict, _ = paddle.load(self.save_dirname, configs)
self.check_load_state_dict(orig_param_dict, load_param_dict) self.check_load_state_dict(orig_param_dict, load_param_dict)
def test_load_with_model_filename(self): def test_load_with_model_filename(self):
...@@ -160,6 +162,14 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase): ...@@ -160,6 +162,14 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
load_param_dict, _ = paddle.load(self.save_dirname, configs) load_param_dict, _ = paddle.load(self.save_dirname, configs)
self.check_load_state_dict(orig_param_dict, load_param_dict) self.check_load_state_dict(orig_param_dict, load_param_dict)
def test_load_state_dict_from_save_params(self):
self.save_dirname = "static_mnist.load_state_dict.save_params"
self.params_filename = None
orig_param_dict = self.train_and_save_model(True)
load_param_dict, _ = paddle.load(self.save_dirname)
self.check_load_state_dict(orig_param_dict, load_param_dict)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -67,6 +67,22 @@ class TestSumOp6D(OpTest): ...@@ -67,6 +67,22 @@ class TestSumOp6D(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestSumOp8D(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {
'X': np.random.random((1, 3, 1, 2, 1, 4, 3, 10)).astype("float64")
}
self.attrs = {'dim': (0, 3)}
self.outputs = {'Out': self.inputs['X'].sum(axis=(0, 3))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
@skip_check_grad_ci( @skip_check_grad_ci(
reason="reduce_max is discontinuous non-derivable function," reason="reduce_max is discontinuous non-derivable function,"
" its gradient check is not supported by unittest framework.") " its gradient check is not supported by unittest framework.")
...@@ -103,6 +119,40 @@ class TestMinOp(OpTest): ...@@ -103,6 +119,40 @@ class TestMinOp(OpTest):
self.check_output() self.check_output()
class TestMin6DOp(OpTest):
"""Remove Min with subgradient from gradient check to confirm the success of CI."""
def setUp(self):
self.op_type = "reduce_min"
self.inputs = {
'X': np.random.random((2, 4, 3, 5, 6, 10)).astype("float64")
}
self.attrs = {'dim': [2, 4]}
self.outputs = {
'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output()
class TestMin8DOp(OpTest):
"""Remove Min with subgradient from gradient check to confirm the success of CI."""
def setUp(self):
self.op_type = "reduce_min"
self.inputs = {
'X': np.random.random((2, 4, 3, 5, 6, 3, 2, 4)).astype("float64")
}
self.attrs = {'dim': [2, 3, 4]}
self.outputs = {
'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output()
class TestProdOp(OpTest): class TestProdOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "reduce_prod" self.op_type = "reduce_prod"
...@@ -116,6 +166,42 @@ class TestProdOp(OpTest): ...@@ -116,6 +166,42 @@ class TestProdOp(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestProd6DOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
self.inputs = {
'X': np.random.random((5, 6, 2, 3, 4, 2)).astype("float64")
}
self.attrs = {'dim': [2, 3, 4]}
self.outputs = {
'Out': self.inputs['X'].prod(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestProd8DOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
self.inputs = {
'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64")
}
self.attrs = {'dim': [2, 3, 4]}
self.outputs = {
'Out': self.inputs['X'].prod(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestAllOp(OpTest): class TestAllOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "reduce_all" self.op_type = "reduce_all"
...@@ -127,12 +213,40 @@ class TestAllOp(OpTest): ...@@ -127,12 +213,40 @@ class TestAllOp(OpTest):
self.check_output() self.check_output()
class TestAll8DOp(OpTest):
def setUp(self):
self.op_type = "reduce_all"
self.inputs = {
'X': np.random.randint(0, 2,
(2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
}
self.attrs = {'reduce_all': True, 'dim': (2, 3, 4)}
self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])}
def test_check_output(self):
self.check_output()
class TestAllOpWithDim(OpTest): class TestAllOpWithDim(OpTest):
def setUp(self): def setUp(self):
self.op_type = "reduce_all" self.op_type = "reduce_all"
self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")}
self.attrs = {'dim': [1]} self.attrs = {'dim': (1, )}
self.outputs = {'Out': self.inputs['X'].all(axis=1)} self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])}
def test_check_output(self):
self.check_output()
class TestAll8DOpWithDim(OpTest):
def setUp(self):
self.op_type = "reduce_all"
self.inputs = {
'X': np.random.randint(0, 2,
(2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
}
self.attrs = {'dim': (1, 3, 4)}
self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -152,6 +266,23 @@ class TestAllOpWithKeepDim(OpTest): ...@@ -152,6 +266,23 @@ class TestAllOpWithKeepDim(OpTest):
self.check_output() self.check_output()
class TestAll8DOpWithKeepDim(OpTest):
def setUp(self):
self.op_type = "reduce_all"
self.inputs = {
'X': np.random.randint(0, 2,
(2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
}
self.attrs = {'dim': (5, ), 'keep_dim': True}
self.outputs = {
'Out': np.expand_dims(
self.inputs['X'].all(axis=self.attrs['dim']), axis=5)
}
def test_check_output(self):
self.check_output()
class TestAllOpError(unittest.TestCase): class TestAllOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
...@@ -175,6 +306,20 @@ class TestAnyOp(OpTest): ...@@ -175,6 +306,20 @@ class TestAnyOp(OpTest):
self.check_output() self.check_output()
class TestAny8DOp(OpTest):
def setUp(self):
self.op_type = "reduce_any"
self.inputs = {
'X': np.random.randint(0, 2,
(2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
}
self.attrs = {'reduce_all': True, 'dim': (3, 5, 4)}
self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])}
def test_check_output(self):
self.check_output()
class TestAnyOpWithDim(OpTest): class TestAnyOpWithDim(OpTest):
def setUp(self): def setUp(self):
self.op_type = "reduce_any" self.op_type = "reduce_any"
...@@ -186,14 +331,45 @@ class TestAnyOpWithDim(OpTest): ...@@ -186,14 +331,45 @@ class TestAnyOpWithDim(OpTest):
self.check_output() self.check_output()
class TestAny8DOpWithDim(OpTest):
def setUp(self):
self.op_type = "reduce_any"
self.inputs = {
'X': np.random.randint(0, 2,
(2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
}
self.attrs = {'dim': (3, 6)}
self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])}
def test_check_output(self):
self.check_output()
class TestAnyOpWithKeepDim(OpTest): class TestAnyOpWithKeepDim(OpTest):
def setUp(self): def setUp(self):
self.op_type = "reduce_any" self.op_type = "reduce_any"
self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")}
self.attrs = {'dim': [1], 'keep_dim': True} self.attrs = {'dim': (1, ), 'keep_dim': True}
self.outputs = {
'Out': np.expand_dims(
self.inputs['X'].any(axis=self.attrs['dim']), axis=1)
}
def test_check_output(self):
self.check_output()
class TestAny8DOpWithKeepDim(OpTest):
def setUp(self):
self.op_type = "reduce_any"
self.inputs = {
'X': np.random.randint(0, 2,
(2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
}
self.attrs = {'dim': (1, ), 'keep_dim': True}
self.outputs = { self.outputs = {
'Out': np.expand_dims( 'Out': np.expand_dims(
self.inputs['X'].any(axis=1), axis=1) self.inputs['X'].any(axis=self.attrs['dim']), axis=1)
} }
def test_check_output(self): def test_check_output(self):
...@@ -283,6 +459,18 @@ class Test3DReduce3(Test1DReduce): ...@@ -283,6 +459,18 @@ class Test3DReduce3(Test1DReduce):
} }
class Test8DReduce0(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.attrs = {'dim': (4, 2, 3)}
self.inputs = {
'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64")
}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
class TestKeepDimReduce(Test1DReduce): class TestKeepDimReduce(Test1DReduce):
def setUp(self): def setUp(self):
self.op_type = "reduce_sum" self.op_type = "reduce_sum"
...@@ -294,6 +482,19 @@ class TestKeepDimReduce(Test1DReduce): ...@@ -294,6 +482,19 @@ class TestKeepDimReduce(Test1DReduce):
} }
class TestKeepDim8DReduce(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {
'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64")
}
self.attrs = {'dim': (3, 4, 5), 'keep_dim': True}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']),
keepdims=self.attrs['keep_dim'])
}
class TestReduceAll(Test1DReduce): class TestReduceAll(Test1DReduce):
def setUp(self): def setUp(self):
self.op_type = "reduce_sum" self.op_type = "reduce_sum"
...@@ -302,6 +503,16 @@ class TestReduceAll(Test1DReduce): ...@@ -302,6 +503,16 @@ class TestReduceAll(Test1DReduce):
self.outputs = {'Out': self.inputs['X'].sum()} self.outputs = {'Out': self.inputs['X'].sum()}
class TestReduceAll(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {
'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64")
}
self.attrs = {'reduce_all': True, 'dim': (3, 4, 5)}
self.outputs = {'Out': self.inputs['X'].sum(axis=self.attrs['dim'])}
@skip_check_grad_ci( @skip_check_grad_ci(
reason="reduce_max is discontinuous non-derivable function," reason="reduce_max is discontinuous non-derivable function,"
" its gradient check is not supported by unittest framework.") " its gradient check is not supported by unittest framework.")
......
...@@ -50,7 +50,7 @@ class TestSaveModelWithoutVar(unittest.TestCase): ...@@ -50,7 +50,7 @@ class TestSaveModelWithoutVar(unittest.TestCase):
params_filename='params') params_filename='params')
expected_warn = "no variable in your model, please ensure there are any variables in your model to save" expected_warn = "no variable in your model, please ensure there are any variables in your model to save"
self.assertTrue(len(w) > 0) self.assertTrue(len(w) > 0)
self.assertTrue(expected_warn == str(w[0].message)) self.assertTrue(expected_warn == str(w[-1].message))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -99,6 +99,18 @@ class TestCase7(TestTransposeOp): ...@@ -99,6 +99,18 @@ class TestCase7(TestTransposeOp):
self.axis = (0, 1, 3, 2) self.axis = (0, 1, 3, 2)
class TestCase8(TestTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
self.axis = (0, 1, 3, 2, 4, 5, 6, 7)
class TestCase9(TestTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
class TestTransposeOpError(unittest.TestCase): class TestTransposeOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
......
...@@ -792,15 +792,14 @@ class Model(object): ...@@ -792,15 +792,14 @@ class Model(object):
switched by `paddle.disable_static()`. The usage is as follows. switched by `paddle.disable_static()`. The usage is as follows.
But note, the switching between dynamic and static should be before But note, the switching between dynamic and static should be before
instantiating a Model. The input description, i.e, paddle.static.InputSpec, instantiating a Model. The input description, i.e, paddle.static.InputSpec,
must be required for static graph. must be required.
Args: Args:
network (paddle.nn.Layer): The network is an instance of network (paddle.nn.Layer): The network is an instance of
paddle.nn.Layer. paddle.nn.Layer.
inputs (InputSpec|list|dict|None): `inputs`, entry points of network, inputs (InputSpec|list|dict|None): `inputs`, entry points of network,
could be a InputSpec instance, or lits of InputSpec instances, could be a InputSpec instance, or lits of InputSpec instances,
or dict ({name: InputSpec}), or None. For static graph, or dict ({name: InputSpec}), and it couldn't be None.
inputs must be set. For dynamic graph, it could be None.
labels (InputSpec|list|None): `labels`, entry points of network, labels (InputSpec|list|None): `labels`, entry points of network,
could be a InputSpec instnace or lits of InputSpec instances, could be a InputSpec instnace or lits of InputSpec instances,
or None. For static graph, if labels is required in loss, or None. For static graph, if labels is required in loss,
...@@ -849,10 +848,9 @@ class Model(object): ...@@ -849,10 +848,9 @@ class Model(object):
self._optimizer = None self._optimizer = None
self._test_dataloader = None self._test_dataloader = None
if not in_dygraph_mode(): if not isinstance(inputs, (list, dict, Input)):
if not isinstance(inputs, (list, dict, Input)): raise TypeError(
raise TypeError( "'inputs' must be list or dict, and couldn't be None.")
"'inputs' must be list or dict in static graph mode")
self._inputs = self._verify_spec(inputs, True) self._inputs = self._verify_spec(inputs, True)
self._labels = self._verify_spec(labels) self._labels = self._verify_spec(labels)
...@@ -1004,11 +1002,7 @@ class Model(object): ...@@ -1004,11 +1002,7 @@ class Model(object):
have no variable need to save (like SGD), the fill will not generated). have no variable need to save (like SGD), the fill will not generated).
This function will silently overwrite existing file at the target location. This function will silently overwrite existing file at the target location.
If `training` is set to False, only inference model will be saved. It If `training` is set to False, only inference model will be saved.
should be noted that before using `save`, you should run the model, and
the shape of input you saved is as same as the input of its running.
`@paddle.jit.to_static` must be added on `forward` function of your layer
in dynamic mode now and these will be optimized later.
Args: Args:
path (str): The file prefix to save model. The format is path (str): The file prefix to save model. The format is
...@@ -1037,8 +1031,6 @@ class Model(object): ...@@ -1037,8 +1031,6 @@ class Model(object):
nn.Linear(200, 10), nn.Linear(200, 10),
nn.Softmax()) nn.Softmax())
# If save for inference in dygraph, need this
@paddle.jit.to_static
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
...@@ -1046,7 +1038,7 @@ class Model(object): ...@@ -1046,7 +1038,7 @@ class Model(object):
device = paddle.set_device('cpu') device = paddle.set_device('cpu')
# if use static graph, do not set # if use static graph, do not set
paddle.disable_static(device) if dynamic else None paddle.disable_static(device) if dynamic else None
# inputs and labels are not required for dynamic graph.
input = InputSpec([None, 784], 'float32', 'x') input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label') label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(Mnist(), input, label) model = paddle.Model(Mnist(), input, label)
...@@ -1649,10 +1641,6 @@ class Model(object): ...@@ -1649,10 +1641,6 @@ class Model(object):
model_only=False): model_only=False):
""" """
Save inference model can be in static or dynamic mode. Save inference model can be in static or dynamic mode.
It should be noted that before using `save_inference_model`, you should
run the model, and the shape you saved is as same as the input of its
running. `@paddle.jit.to_static` must be added on `forward` function of
your layer in dynamic mode now and these will be optimized later.
Args: Args:
save_dir (str): The directory path to save the inference model. save_dir (str): The directory path to save the inference model.
...@@ -1678,20 +1666,17 @@ class Model(object): ...@@ -1678,20 +1666,17 @@ class Model(object):
return result_list return result_list
# TODO:
# 1. Make it Unnecessary to run model before calling `save_inference_model` for users in dygraph.
# 2. Save correct shape of input, now the interface stores the shape that the user sent to
# the inputs of the model in running.
# 3. Make it Unnecessary to add `@paddle.jit.to_static` for users in dynamic mode.
if fluid.in_dygraph_mode(): if fluid.in_dygraph_mode():
with fluid.framework._dygraph_guard(None): with fluid.framework._dygraph_guard(None):
layer = self.network layer = self.network
layer.forward = paddle.jit.to_static(
layer.forward, input_spec=self._inputs)
# 1. input check # 1. input check
prog_translator = ProgramTranslator() prog_translator = ProgramTranslator()
if not prog_translator.enable_declarative: if not prog_translator.enable_to_static:
raise RuntimeError( raise RuntimeError(
"save_inference_model doesn't work when setting ProgramTranslator.enable=False." "save_inference_model doesn't work when setting ProgramTranslator.enable to False."
) )
if not isinstance(layer, Layer): if not isinstance(layer, Layer):
raise TypeError( raise TypeError(
...@@ -1879,18 +1864,7 @@ class Model(object): ...@@ -1879,18 +1864,7 @@ class Model(object):
def _verify_spec(self, specs, is_input=False): def _verify_spec(self, specs, is_input=False):
out_specs = [] out_specs = []
if specs is None: if isinstance(specs, dict):
# Note(Aurelius84): If not specific specs of `Input`, using argument names of `forward` function
# to generate `Input`. But how can we know the actual shape of each input tensor?
if is_input:
out_specs = [
Input(
name=n, shape=[None])
for n in extract_args(self.network.forward) if n != 'self'
]
else:
out_specs = to_list(specs)
elif isinstance(specs, dict):
assert is_input == False assert is_input == False
out_specs = [specs[n] \ out_specs = [specs[n] \
for n in extract_args(self.network.forward) if n != 'self'] for n in extract_args(self.network.forward) if n != 'self']
...@@ -1902,8 +1876,8 @@ class Model(object): ...@@ -1902,8 +1876,8 @@ class Model(object):
assert isinstance(spec, Input) assert isinstance(spec, Input)
if spec.name is None: if spec.name is None:
raise ValueError( raise ValueError(
"Requires Input[{}].name != None, but receive `None` with {}.". "Requires Input[{}].name != None, but receive `None` with {}."
format(i, spec)) .format(i, spec))
return out_specs return out_specs
......
...@@ -32,6 +32,21 @@ import random ...@@ -32,6 +32,21 @@ import random
import zlib import zlib
import paddle.compat as cpt import paddle.compat as cpt
# On macOS, the 'spawn' start method is now the default in Python3.8 multiprocessing,
# Paddle is currently unable to solve this, so forces the process to start using
# the 'fork' start method.
#
# TODO: This solution is not good, because the fork start method could lead to
# crashes of the subprocess. Figure out how to make 'spawn' work.
#
# For more details, please refer to
# https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
# https://bugs.python.org/issue33725
if sys.version_info >= (3, 8):
fork_context = multiprocessing.get_context('fork')
else:
fork_context = multiprocessing
def cache(reader): def cache(reader):
""" """
...@@ -560,9 +575,9 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000): ...@@ -560,9 +575,9 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
def queue_reader(): def queue_reader():
queue = multiprocessing.Queue(queue_size) queue = fork_context.Queue(queue_size)
for reader in readers: for reader in readers:
p = multiprocessing.Process( p = fork_context.Process(
target=_read_into_queue, args=(reader, queue)) target=_read_into_queue, args=(reader, queue))
p.start() p.start()
...@@ -593,9 +608,9 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000): ...@@ -593,9 +608,9 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
def pipe_reader(): def pipe_reader():
conns = [] conns = []
for reader in readers: for reader in readers:
parent_conn, child_conn = multiprocessing.Pipe() parent_conn, child_conn = fork_context.Pipe()
conns.append(parent_conn) conns.append(parent_conn)
p = multiprocessing.Process( p = fork_context.Process(
target=_read_into_pipe, args=(reader, child_conn)) target=_read_into_pipe, args=(reader, child_conn))
p.start() p.start()
......
...@@ -67,35 +67,6 @@ class LeNetDygraph(paddle.nn.Layer): ...@@ -67,35 +67,6 @@ class LeNetDygraph(paddle.nn.Layer):
return x return x
class LeNetDeclarative(fluid.dygraph.Layer):
def __init__(self, num_classes=10, classifier_activation=None):
super(LeNetDeclarative, self).__init__()
self.num_classes = num_classes
self.features = Sequential(
Conv2d(
1, 6, 3, stride=1, padding=1),
ReLU(),
Pool2D(2, 'max', 2),
Conv2d(
6, 16, 5, stride=1, padding=0),
ReLU(),
Pool2D(2, 'max', 2))
if num_classes > 0:
self.fc = Sequential(
Linear(400, 120), Linear(120, 84), Linear(84, 10),
Softmax()) #Todo: accept any activation
@declarative
def forward(self, inputs):
x = self.features(inputs)
if self.num_classes > 0:
x = fluid.layers.flatten(x, 1)
x = self.fc(x)
return x
class MnistDataset(MNIST): class MnistDataset(MNIST):
def __init__(self, mode, return_label=True, sample_num=None): def __init__(self, mode, return_label=True, sample_num=None):
super(MnistDataset, self).__init__(mode=mode) super(MnistDataset, self).__init__(mode=mode)
...@@ -444,7 +415,9 @@ class TestModelFunction(unittest.TestCase): ...@@ -444,7 +415,9 @@ class TestModelFunction(unittest.TestCase):
# dynamic saving # dynamic saving
device = paddle.set_device('cpu') device = paddle.set_device('cpu')
fluid.enable_dygraph(device) fluid.enable_dygraph(device)
model = Model(MyModel(classifier_activation=None)) inputs = [InputSpec([None, 20], 'float32', 'x')]
labels = [InputSpec([None, 1], 'int64', 'label')]
model = Model(MyModel(classifier_activation=None), inputs, labels)
optim = fluid.optimizer.SGD(learning_rate=0.001, optim = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=model.parameters()) parameter_list=model.parameters())
model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum")) model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
...@@ -543,11 +516,10 @@ class TestModelFunction(unittest.TestCase): ...@@ -543,11 +516,10 @@ class TestModelFunction(unittest.TestCase):
def test_export_deploy_model(self): def test_export_deploy_model(self):
for dynamic in [True, False]: for dynamic in [True, False]:
fluid.enable_dygraph() if dynamic else None paddle.disable_static() if dynamic else None
# paddle.disable_static() if dynamic else None
prog_translator = ProgramTranslator() prog_translator = ProgramTranslator()
prog_translator.enable(False) if not dynamic else None prog_translator.enable(False) if not dynamic else None
net = LeNetDeclarative() net = LeNet()
inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')] inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
model = Model(net, inputs) model = Model(net, inputs)
model.prepare() model.prepare()
...@@ -556,8 +528,9 @@ class TestModelFunction(unittest.TestCase): ...@@ -556,8 +528,9 @@ class TestModelFunction(unittest.TestCase):
os.makedirs(save_dir) os.makedirs(save_dir)
tensor_img = np.array( tensor_img = np.array(
np.random.random((1, 1, 28, 28)), dtype=np.float32) np.random.random((1, 1, 28, 28)), dtype=np.float32)
ori_results = model.test_batch(tensor_img)
model.save(save_dir, training=False) model.save(save_dir, training=False)
ori_results = model.test_batch(tensor_img)
fluid.disable_dygraph() if dynamic else None fluid.disable_dygraph() if dynamic else None
place = fluid.CPUPlace() if not fluid.is_compiled_with_cuda( place = fluid.CPUPlace() if not fluid.is_compiled_with_cuda(
...@@ -574,6 +547,7 @@ class TestModelFunction(unittest.TestCase): ...@@ -574,6 +547,7 @@ class TestModelFunction(unittest.TestCase):
np.testing.assert_allclose( np.testing.assert_allclose(
results, ori_results, rtol=1e-5, atol=1e-7) results, ori_results, rtol=1e-5, atol=1e-7)
shutil.rmtree(save_dir) shutil.rmtree(save_dir)
paddle.enable_static()
class TestRaiseError(unittest.TestCase): class TestRaiseError(unittest.TestCase):
...@@ -585,6 +559,14 @@ class TestRaiseError(unittest.TestCase): ...@@ -585,6 +559,14 @@ class TestRaiseError(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
model = Model(net, inputs, labels) model = Model(net, inputs, labels)
def test_input_without_input_spec(self):
for dynamic in [True, False]:
paddle.disable_static() if dynamic else None
net = MyModel(classifier_activation=None)
with self.assertRaises(TypeError):
model = Model(net)
paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -13,11 +13,10 @@ scipy ; python_version>"3.5" ...@@ -13,11 +13,10 @@ scipy ; python_version>"3.5"
nltk ; python_version>="3.5" nltk ; python_version>="3.5"
rarfile rarfile
Pillow Pillow
graphviz
six six
decorator decorator
prettytable prettytable
objgraph
astor astor
pathlib pathlib
netifaces netifaces ; platform_system != "Windows"
netifaces ; python_version>="3.5" and platform_system == "Windows"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import os
import os.path
import errno
import re
import shutil
import sys
import fnmatch
import errno
import platform
from contextlib import contextmanager
from setuptools import Command
from setuptools import setup, Distribution, Extension
from setuptools.command.install import install as InstallCommandBase
class BinaryDistribution(Distribution):
def has_ext_modules(foo):
return True
RC = 0
ext_name = '.dll' if os.name == 'nt' else ('.dylib' if sys.platform == 'darwin'
else '.so')
def git_commit():
try:
cmd = ['git', 'rev-parse', 'HEAD']
git_commit = subprocess.Popen(
cmd, stdout=subprocess.PIPE,
cwd="@PADDLE_SOURCE_DIR@").communicate()[0].strip()
except:
git_commit = 'Unknown'
git_commit = git_commit.decode()
return str(git_commit)
def _get_version_detail(idx):
assert idx < 3, "vesion info consists of %(major)d.%(minor)d.%(patch)d, \
so detail index must less than 3"
if re.match('@TAG_VERSION_REGEX@', '@PADDLE_VERSION@'):
version_details = '@PADDLE_VERSION@'.split('.')
if len(version_details) >= 3:
return version_details[idx]
return 0
def get_major():
return int(_get_version_detail(0))
def get_minor():
return int(_get_version_detail(1))
def get_patch():
return str(_get_version_detail(2))
def is_taged():
try:
cmd = [
'git', 'describe', '--exact-match', '--tags', 'HEAD', '2>/dev/null'
]
git_tag = subprocess.Popen(
cmd, stdout=subprocess.PIPE,
cwd="@PADDLE_SOURCE_DIR@").communicate()[0].strip()
git_tag = git_tag.decode()
except:
return False
if str(git_tag).replace('v', '') == '@PADDLE_VERSION@':
return True
else:
return False
def write_version_py(filename='paddle/version.py'):
cnt = '''# THIS FILE IS GENERATED FROM PADDLEPADDLE SETUP.PY
#
full_version = '%(major)d.%(minor)d.%(patch)s'
major = '%(major)d'
minor = '%(minor)d'
patch = '%(patch)s'
rc = '%(rc)d'
istaged = %(istaged)s
commit = '%(commit)s'
with_mkl = '%(with_mkl)s'
def show():
if istaged:
print('full_version:', full_version)
print('major:', major)
print('minor:', minor)
print('patch:', patch)
print('rc:', rc)
else:
print('commit:', commit)
def mkl():
return with_mkl
'''
commit = git_commit()
with open(filename, 'w') as f:
f.write(cnt % {
'major': get_major(),
'minor': get_minor(),
'patch': get_patch(),
'rc': RC,
'version': '${PADDLE_VERSION}',
'commit': commit,
'istaged': is_taged(),
'with_mkl': '@WITH_MKL@'
})
write_version_py(filename='@PADDLE_BINARY_DIR@/python/paddle/version.py')
def write_distributed_training_mode_py(
filename='paddle/fluid/incubate/fleet/parameter_server/version.py'):
cnt = '''from __future__ import print_function
# THIS FILE IS GENERATED FROM PADDLEPADDLE SETUP.PY
from paddle.fluid.incubate.fleet.base.mode import Mode
BUILD_MODE=Mode.%(mode)s
def is_transpiler():
return Mode.TRANSPILER == BUILD_MODE
'''
dirname = os.path.dirname(filename)
try:
os.makedirs(dirname)
except OSError as e:
if e.errno != errno.EEXIST:
raise
with open(filename, 'w') as f:
f.write(cnt %
{'mode': 'PSLIB' if '${WITH_PSLIB}' == 'ON' else 'TRANSPILER'})
write_distributed_training_mode_py(
filename='@PADDLE_BINARY_DIR@/python/paddle/fluid/incubate/fleet/parameter_server/version.py'
)
packages = [
'paddle',
'paddle.libs',
'paddle.utils',
'paddle.dataset',
'paddle.reader',
'paddle.distributed',
'paddle.incubate',
'paddle.incubate.complex',
'paddle.incubate.complex.tensor',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.meta_optimizers',
'paddle.distributed.fleet.runtime',
'paddle.distributed.fleet.dataset',
'paddle.distributed.fleet.metrics',
'paddle.distributed.fleet.proto',
'paddle.distributed.fleet.utils',
'paddle.framework',
'paddle.jit',
'paddle.fluid',
'paddle.fluid.inference',
'paddle.fluid.dygraph',
'paddle.fluid.dygraph.dygraph_to_static',
'paddle.fluid.dygraph.amp',
'paddle.fluid.proto',
'paddle.fluid.proto.profiler',
'paddle.fluid.distributed',
'paddle.fluid.layers',
'paddle.fluid.dataloader',
'paddle.fluid.contrib',
'paddle.fluid.contrib.decoder',
'paddle.fluid.contrib.quantize',
'paddle.fluid.contrib.reader',
'paddle.fluid.contrib.slim',
'paddle.fluid.contrib.slim.quantization',
'paddle.fluid.contrib.slim.quantization.imperative',
'paddle.fluid.contrib.utils',
'paddle.fluid.contrib.extend_optimizer',
'paddle.fluid.contrib.mixed_precision',
'paddle.fluid.contrib.layers',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details',
'paddle.fluid.incubate',
'paddle.fluid.incubate.data_generator',
'paddle.fluid.incubate.fleet',
'paddle.fluid.incubate.checkpoint',
'paddle.fluid.incubate.fleet.base',
'paddle.fluid.incubate.fleet.parameter_server',
'paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler',
'paddle.fluid.incubate.fleet.parameter_server.pslib',
'paddle.fluid.incubate.fleet.parameter_server.ir',
'paddle.fluid.incubate.fleet.collective',
'paddle.fluid.incubate.fleet.utils',
'paddle.hapi',
'paddle.vision',
'paddle.vision.models',
'paddle.vision.transforms',
'paddle.vision.datasets',
'paddle.text',
'paddle.text.datasets',
'paddle.incubate',
'paddle.io',
'paddle.optimizer',
'paddle.nn',
'paddle.nn.functional',
'paddle.nn.layer',
'paddle.nn.initializer',
'paddle.nn.utils',
'paddle.metric',
'paddle.static',
'paddle.static.nn',
'paddle.tensor',
]
with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f:
setup_requires = f.read().splitlines()
# Note(wangzhongpu):
# When compiling paddle under python36, the dependencies belonging to python2.7 will be imported, resulting in errors when installing paddle
if sys.version_info >= (3, 6) and sys.version_info < (3, 7):
setup_requires_tmp = []
for setup_requires_i in setup_requires:
if "<\"3.6\"" in setup_requires_i or "<\"3.5\"" in setup_requires_i or "<=\"3.5\"" in setup_requires_i:
continue
setup_requires_tmp += [setup_requires_i]
setup_requires = setup_requires_tmp
if sys.version_info >= (3, 5) and sys.version_info < (3, 6):
setup_requires_tmp = []
for setup_requires_i in setup_requires:
if "<\"3.5\"" in setup_requires_i:
continue
setup_requires_tmp += [setup_requires_i]
setup_requires = setup_requires_tmp
if sys.version_info >= (3, 7):
setup_requires_tmp = []
for setup_requires_i in setup_requires:
if "<\"3.6\"" in setup_requires_i or "<=\"3.6\"" in setup_requires_i or "<\"3.5\"" in setup_requires_i or "<=\"3.5\"" in setup_requires_i or "<\"3.7\"" in setup_requires_i:
continue
setup_requires_tmp += [setup_requires_i]
setup_requires = setup_requires_tmp
if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']:
setup_requires += ['opencv-python']
# the prefix is sys.prefix which should always be usr
paddle_bins = ''
if not '${WIN32}':
paddle_bins = ['${PADDLE_BINARY_DIR}/paddle/scripts/paddle']
package_data = {
'paddle.fluid':
['${FLUID_CORE_NAME}' + ('.so' if os.name != 'nt' else '.pyd')]
}
if '${HAS_NOAVX_CORE}' == 'ON':
package_data['paddle.fluid'] += [
'core_noavx' + ('.so' if os.name != 'nt' else '.pyd')
]
package_dir = {
'': '${PADDLE_BINARY_DIR}/python',
# The paddle.fluid.proto will be generated while compiling.
# So that package points to other directory.
'paddle.fluid.proto.profiler': '${PADDLE_BINARY_DIR}/paddle/fluid/platform',
'paddle.fluid.proto': '${PADDLE_BINARY_DIR}/paddle/fluid/framework',
'paddle.fluid': '${PADDLE_BINARY_DIR}/python/paddle/fluid',
}
# put all thirdparty libraries in paddle.libs
libs_path = '${PADDLE_BINARY_DIR}/python/paddle/libs'
package_data['paddle.libs'] = []
package_data['paddle.libs'] = [('libwarpctc'
if os.name != 'nt' else 'warpctc') + ext_name]
shutil.copy('${WARPCTC_LIBRARIES}', libs_path)
if '${WITH_MKL}' == 'ON':
shutil.copy('${MKLML_SHARED_LIB}', libs_path)
shutil.copy('${MKLML_SHARED_IOMP_LIB}', libs_path)
package_data['paddle.libs'] += [
('libmklml_intel' if os.name != 'nt' else 'mklml') + ext_name,
('libiomp5' if os.name != 'nt' else 'libiomp5md') + ext_name
]
else:
if os.name == 'nt':
# copy the openblas.dll
shutil.copy('${OPENBLAS_SHARED_LIB}', libs_path)
package_data['paddle.libs'] += ['openblas' + ext_name]
if '${WITH_LITE}' == 'ON':
shutil.copy('${LITE_SHARED_LIB}', libs_path)
package_data['paddle.libs'] += ['libpaddle_full_api_shared' + ext_name]
if '${WITH_PSLIB}' == 'ON':
shutil.copy('${PSLIB_LIB}', libs_path)
if os.path.exists('${PSLIB_VERSION_PY}'):
shutil.copy(
'${PSLIB_VERSION_PY}',
'${PADDLE_BINARY_DIR}/python/paddle/fluid/incubate/fleet/parameter_server/pslib/'
)
package_data['paddle.libs'] += ['libps' + ext_name]
if '${WITH_MKLDNN}' == 'ON':
if '${CMAKE_BUILD_TYPE}' == 'Release' and os.name != 'nt':
# only change rpath in Release mode.
# TODO(typhoonzero): use install_name_tool to patch mkl libs once
# we can support mkl on mac.
#
# change rpath of libdnnl.so.1, add $ORIGIN/ to it.
# The reason is that all thirdparty libraries in the same directory,
# thus, libdnnl.so.1 will find libmklml_intel.so and libiomp5.so.
command = "patchelf --set-rpath '$ORIGIN/' ${MKLDNN_SHARED_LIB}"
if os.system(command) != 0:
raise Exception("patch libdnnl.so failed, command: %s" % command)
shutil.copy('${MKLDNN_SHARED_LIB}', libs_path)
if os.name != 'nt':
shutil.copy('${MKLDNN_SHARED_LIB_1}', libs_path)
package_data['paddle.libs'] += ['libmkldnn.so.0', 'libdnnl.so.1']
else:
package_data['paddle.libs'] += ['mkldnn.dll']
if '${WITH_XPU}' == 'ON':
# only change rpath in Release mode,
if '${CMAKE_BUILD_TYPE}' == 'Release':
if os.name != 'nt':
if "@APPLE@" == "1":
command = "install_name_tool -id \"@loader_path/\" ${XPU_API_LIB}"
else:
command = "patchelf --set-rpath '$ORIGIN/' ${XPU_API_LIB}"
if os.system(command) != 0:
raise Exception("patch ${XPU_API_LIB} failed, command: %s" %
command)
shutil.copy('${XPU_API_LIB}', libs_path)
shutil.copy('${XPU_RT_LIB}', libs_path)
shutil.copy('${XPU_SIM_LIB}', libs_path)
package_data['paddle.libs'] += [
'${XPU_API_LIB_NAME}', '${XPU_RT_LIB_NAME}', '${XPU_SIM_LIB_NAME}'
]
# copy libfuild_framework.so to libs
if os.name != 'nt' and sys.platform != 'darwin':
paddle_framework_lib = '${FLUID_FRAMEWORK_SHARED_LIB}'
shutil.copy(paddle_framework_lib, libs_path)
package_data['paddle.libs'] += [
('libpaddle_framework'
if os.name != 'nt' else 'paddle_framework') + ext_name
]
# remove unused paddle/libs/__init__.py
if os.path.isfile(libs_path + '/__init__.py'):
os.remove(libs_path + '/__init__.py')
package_dir['paddle.libs'] = libs_path
# change rpath of ${FLUID_CORE_NAME}.ext, add $ORIGIN/../libs/ to it.
# The reason is that libwarpctc.ext, libiomp5.ext etc are in paddle.libs, and
# ${FLUID_CORE_NAME}.ext is in paddle.fluid, thus paddle/fluid/../libs will pointer to above libraries.
# This operation will fix https://github.com/PaddlePaddle/Paddle/issues/3213
if '${CMAKE_BUILD_TYPE}' == 'Release':
if os.name != 'nt':
# only change rpath in Release mode, since in Debug mode, ${FLUID_CORE_NAME}.xx is too large to be changed.
if "@APPLE@" == "1":
command = "install_name_tool -id \"@loader_path/../libs/\" ${PADDLE_BINARY_DIR}/python/paddle/fluid/${FLUID_CORE_NAME}" + '.so'
else:
command = "patchelf --set-rpath '$ORIGIN/../libs/' ${PADDLE_BINARY_DIR}/python/paddle/fluid/${FLUID_CORE_NAME}" + '.so'
# The dynamic library compiled under aarch64 is greater than 64M,
# and an oversize error will be reported when using patchelf.
if platform.machine() != 'aarch64':
if os.system(command) != 0:
raise Exception(
"patch ${FLUID_CORE_NAME}.%s failed, command: %s" %
(ext_name, command))
ext_modules = [Extension('_foo', ['stub.cc'])]
if os.name == 'nt':
# fix the path separator under windows
fix_package_dir = {}
for k, v in package_dir.items():
fix_package_dir[k] = v.replace('/', '\\')
package_dir = fix_package_dir
ext_modules = []
elif sys.platform == 'darwin':
ext_modules = []
def find_files(pattern, root):
for dirpath, _, files in os.walk(root):
for filename in fnmatch.filter(files, pattern):
yield os.path.join(dirpath, filename)
headers = (
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/framework')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/imperative')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/memory')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/platform')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/string')) +
list(find_files('*.pb.h', '${PADDLE_BINARY_DIR}/paddle/fluid/platform')) +
list(find_files('*.pb.h', '${PADDLE_BINARY_DIR}/paddle/fluid/framework')) +
list(find_files('*.pb', '${cudaerror_INCLUDE_DIR}'))
+ # errorMessage.pb for errormessage
['${EIGEN_INCLUDE_DIR}/Eigen/Core'] + # eigen
list(find_files('*', '${EIGEN_INCLUDE_DIR}/Eigen/src')) + # eigen
list(find_files('*', '${EIGEN_INCLUDE_DIR}/unsupported/Eigen')) + # eigen
list(find_files('*', '${GFLAGS_INSTALL_DIR}/include')) + # gflags
list(find_files('*', '${GLOG_INSTALL_DIR}/include')) + # glog
list(find_files('*', '${BOOST_INCLUDE_DIR}/boost')) + # boost
list(find_files('*', '${XXHASH_INSTALL_DIR}/include')) + # xxhash
list(find_files('*', '${PROTOBUF_INCLUDE_DIR}')) + # protobuf
list(find_files('*', '${DLPACK_INCLUDE_DIR}')) + # dlpack
list(find_files('*.h', '${THREADPOOL_INCLUDE_DIR}'))) # threadpool
if '${WITH_MKLDNN}' == 'ON':
headers += list(find_files('*', '${MKLDNN_INSTALL_DIR}/include')) # mkldnn
if '${WITH_GPU}' == 'ON':
headers += list(find_files(
'*.pb', '${cudaerror_INCLUDE_DIR}')) # errorMessage.pb for errormessage
class InstallCommand(InstallCommandBase):
def finalize_options(self):
ret = InstallCommandBase.finalize_options(self)
self.install_headers = os.path.join(self.install_purelib, 'paddle',
'include')
self.install_lib = self.install_platlib
return ret
class InstallHeaders(Command):
"""Override how headers are copied.
"""
description = 'install C/C++ header files'
user_options = [
('install-dir=', 'd', 'directory to install header files to'),
('force', 'f', 'force installation (overwrite existing files)'),
]
boolean_options = ['force']
def initialize_options(self):
self.install_dir = None
self.force = 0
self.outfiles = []
def finalize_options(self):
self.set_undefined_options(
'install', ('install_headers', 'install_dir'), ('force', 'force'))
def mkdir_and_copy_file(self, header):
if 'pb.h' in header:
install_dir = re.sub('${PADDLE_BINARY_DIR}/', '', header)
elif 'third_party' not in header:
# framework
install_dir = re.sub('@PADDLE_SOURCE_DIR@/', '', header)
else:
# third_party
install_dir = re.sub('${THIRD_PARTY_PATH}', 'third_party', header)
patterns = [
'eigen3/src/extern_eigen3', 'boost/src/extern_boost',
'dlpack/src/extern_dlpack/include', 'install/protobuf/include',
'install/gflags/include', 'install/glog/include',
'install/xxhash/include', 'install/mkldnn/include',
'threadpool/src/extern_threadpool'
]
for pattern in patterns:
install_dir = re.sub(pattern, '', install_dir)
install_dir = os.path.join(self.install_dir,
os.path.dirname(install_dir))
if not os.path.exists(install_dir):
self.mkpath(install_dir)
return self.copy_file(header, install_dir)
def run(self):
# only copy third_party/cudaErrorMessage.pb for cudaErrorMessage on mac or windows
if os.name == 'nt' or sys.platform == 'darwin':
if '${WITH_GPU}' == 'ON':
self.mkdir_and_copy_file(
'${cudaerror_INCLUDE_DIR}/cudaErrorMessage.pb')
return
hdrs = self.distribution.headers
if not hdrs:
return
self.mkpath(self.install_dir)
for header in hdrs:
(out, _) = self.mkdir_and_copy_file(header)
self.outfiles.append(out)
def get_inputs(self):
return self.distribution.headers or []
def get_outputs(self):
return self.outfiles
# we redirect setuptools log for non-windows
if sys.platform != 'win32':
@contextmanager
def redirect_stdout():
f_log = open('${SETUP_LOG_FILE}', 'w')
origin_stdout = sys.stdout
sys.stdout = f_log
yield
f_log = sys.stdout
sys.stdout = origin_stdout
f_log.close()
else:
@contextmanager
def redirect_stdout():
yield
if '${WITH_GPU}' == 'ON':
os.environ['PACKAGE_NAME'] = "paddlepaddle-gpu"
else:
os.environ['PACKAGE_NAME'] = "paddlepaddle"
with redirect_stdout():
setup(
name='${PACKAGE_NAME}',
version='${PADDLE_VERSION}',
description='Parallel Distributed Deep Learning',
install_requires=setup_requires,
packages=packages,
ext_modules=ext_modules,
package_data=package_data,
package_dir=package_dir,
scripts=paddle_bins,
distclass=BinaryDistribution,
headers=headers,
cmdclass={
'install_headers': InstallHeaders,
'install': InstallCommand,
},
entry_points={
'console_scripts':
['fleetrun = paddle.distributed.fleet.launch:launch']
})
# As there are a lot of files in purelib which causes many logs,
# we don't print them on the screen, and you can open `setup.py.log`
# for the full logs.
if os.path.exists('${SETUP_LOG_FILE}'):
os.system('grep -v "purelib" ${SETUP_LOG_FILE}')
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" For the PR that only modified the unit test, get cases in pull request. """
import os
import json
from github import Github
PADDLE_ROOT = os.getenv('PADDLE_ROOT', '/paddle/')
class PRChecker(object):
""" PR Checker. """
def __init__(self):
self.github = Github(os.getenv('GITHUB_API_TOKEN'), timeout=60)
self.repo = self.github.get_repo('PaddlePaddle/Paddle')
self.pr = None
def init(self):
""" Get pull request. """
pr_id = os.getenv('GIT_PR_ID')
if not pr_id:
print('No PR ID')
exit(0)
self.pr = self.repo.get_pull(int(pr_id))
def get_pr_files(self):
""" Get files in pull request. """
page = 0
file_list = []
while True:
files = self.pr.get_files().get_page(page)
if not files:
break
for f in files:
file_list.append(PADDLE_ROOT + f.filename)
page += 1
return file_list
def get_pr_ut(self):
""" Get unit tests in pull request. """
ut_list = []
file_ut_map = None
cmd = 'wget -q --no-check-certificate https://sys-p0.bj.bcebos.com/prec/file_ut.json'
os.system(cmd)
with open('file_ut.json') as jsonfile:
file_ut_map = json.load(jsonfile)
for f in self.get_pr_files():
if f not in file_ut_map:
return ''
if f.endswith('.h') or f.endswith('.cu'):
return ''
else:
ut_list.extend(file_ut_map.get(f))
ut_list = list(set(ut_list))
return ' '.join(ut_list)
if __name__ == '__main__':
pr_checker = PRChecker()
pr_checker.init()
print(pr_checker.get_pr_ut())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册