Commit d80ae5bc authored by: zlsh80826

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into nvinfer_plugin_exp_merge

......@@ -12,67 +12,122 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/generator.h"
#include <glog/logging.h>
#include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/generator.h"
namespace paddle {
namespace framework {
std::shared_ptr<Generator> Generator::gen_instance_ = NULL;
const std::shared_ptr<Generator>& DefaultCPUGenerator() {
static auto default_cpu_generator =
std::make_shared<Generator>(GetRandomSeed());
VLOG(4) << "initial seed: " << default_cpu_generator->GetCurrentSeed()
<< ", cpu engine: " << default_cpu_generator->GetCPUEngine().get();
return default_cpu_generator;
}
std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine() {
static auto op_default_cpu_engine = std::make_shared<std::mt19937_64>();
return op_default_cpu_engine;
}
// NOTE(zhiqiu): there are 3 conditions:
// (1) op seed is not set and DefaultCPUGenerator is inited, use
// DefaultCPUGenerator
// (2) op seed is not set and DefaultCPUGenerator is not inited, use
// OpDefaultCPUEngine() and set a random seed
// (3) op seed is set, use OpDefaultCPUEngine() and set the seed
std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) {
if (DefaultCPUGenerator()->GetIsInitPy() && seed == 0) {
VLOG(4) << "Use random engine from generator";
return DefaultCPUGenerator()->GetCPUEngine();
} else {
// NOTE(zhiqiu): create an engine instance every time instead of using
// OpDefaultCPUEngine(); this is the legacy behavior of random operators.
// The benefit is that when running PE with a fixed seed in multiple threads,
// each thread has its own engine and they do not affect each other.
//
// We still need to measure the determinism of Generator in PE.
auto engine = std::make_shared<std::mt19937_64>();
if (seed == 0) {
seed = GetRandomSeed();
VLOG(4) << "Use default random engine with random seed = " << seed;
} else {
VLOG(4) << "Use default random engine with fixed random seed = " << seed;
}
static std::mutex mu_;
{
std::lock_guard<std::mutex> lock(mu_);
engine->seed(seed);
}
return engine;
}
}
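For illustration, a minimal sketch (a hypothetical kernel body, not part of this commit) of how an operator consumes the engine returned by GetCPURandomEngine; passing seed == 0 means "use the global generator when it has been initialized from Python":

// Hypothetical example: fill a buffer with uniform random floats using the
// engine resolved by GetCPURandomEngine().
#include <cstdint>
#include <random>
#include "paddle/fluid/framework/generator.h"

void FillUniform(float* data, int64_t n, uint64_t seed) {
  auto engine = paddle::framework::GetCPURandomEngine(seed);
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  for (int64_t i = 0; i < n; ++i) {
    data[i] = dist(*engine);  // engine is a std::shared_ptr<std::mt19937_64>
  }
}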
GeneratorState* Generator::GetState() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_.get();
GeneratorState Generator::GetState() {
std::lock_guard<std::mutex> lock(this->mu_);
state_.cpu_engine = *engine_;
return this->state_;
}
void Generator::SetState(GeneratorState* state_in) {
std::lock_guard<std::mutex> lock(this->mutex);
*this->state_ = *state_in;
void Generator::SetState(const GeneratorState& state) {
std::lock_guard<std::mutex> lock(this->mu_);
this->state_ = state;
this->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine);
}
uint64_t Generator::GetCurrentSeed() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->current_seed;
std::lock_guard<std::mutex> lock(this->mu_);
return this->state_.current_seed;
}
uint64_t Generator::Seed() {
std::lock_guard<std::mutex> lock(this->mutex);
std::lock_guard<std::mutex> lock(this->mu_);
uint64_t seed;
std::random_device de;
seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF;
this->state_->current_seed = seed;
this->state_.current_seed = seed;
std::seed_seq seq({seed});
this->state_->cpu_engine.seed(seq);
this->engine_->seed(seq);
return this->state_->current_seed;
return this->state_.current_seed;
}
void Generator::SetCurrentSeed(uint64_t seed) {
std::lock_guard<std::mutex> lock(this->mutex);
this->state_->current_seed = uint64_t(seed);
std::lock_guard<std::mutex> lock(this->mu_);
this->state_.current_seed = seed;
std::seed_seq seq({seed});
this->state_->cpu_engine.seed(seq);
this->engine_->seed(seq);
}
std::mt19937_64& Generator::GetCPUEngine() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->cpu_engine;
std::shared_ptr<std::mt19937_64> Generator::GetCPUEngine() {
std::lock_guard<std::mutex> lock(this->mu_);
return this->engine_;
}
void Generator::SetCPUEngine(std::mt19937_64 engine) {
std::lock_guard<std::mutex> lock(this->mutex);
this->state_->cpu_engine = std::mt19937_64(engine);
void Generator::SetCPUEngine(std::shared_ptr<std::mt19937_64> engine) {
std::lock_guard<std::mutex> lock(this->mu_);
this->engine_ = engine;
}
uint64_t Generator::Random64() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->cpu_engine();
std::lock_guard<std::mutex> lock(this->mu_);
auto engine = this->engine_;
return (*engine)();
}
void Generator::SetIsInitPy(bool is_init_py) {
this->is_init_py_ = is_init_py;
VLOG(4) << "SetIsInitPy:" << this->is_init_py_;
}
bool Generator::GetIsInitPy() const { return this->is_init_py_; }
} // namespace framework
} // namespace paddle
......@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <stdint.h>
#include <atomic>
#include <deque>
#include <iostream> // temp for debug
......@@ -27,6 +29,12 @@ limitations under the License. */
namespace paddle {
namespace framework {
static uint64_t GetRandomSeed() {
std::random_device rd;
// double has a 53-bit significand, so limit the uint64 seed to 53 bits
// (0x1FFFFFFFFFFFFF == 2^53 - 1)
return ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
}
struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
......@@ -35,62 +43,67 @@ struct GeneratorState {
struct Generator {
Generator() {
GeneratorState default_gen_state_cpu;
default_gen_state_cpu.device = -1;
default_gen_state_cpu.current_seed = 34342423252;
std::seed_seq seq({34342423252});
default_gen_state_cpu.cpu_engine = std::mt19937_64(seq);
this->state_ = std::make_shared<GeneratorState>(default_gen_state_cpu);
auto seed = GetRandomSeed();
std::seed_seq seq({seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
this->state_.cpu_engine = *engine;
this->state_.device = -1;
this->state_.current_seed = seed;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
}
explicit Generator(uint64_t seed) {
std::seed_seq seq({seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
this->state_.cpu_engine = *engine;
this->state_.device = -1;
this->state_.current_seed = seed;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = true; // TODO(zhiqiu): remove it in future
}
explicit Generator(GeneratorState state_in)
: state_{std::make_shared<GeneratorState>(state_in)} {}
Generator(const Generator& other)
: Generator(other, std::lock_guard<std::mutex>(other.mutex)) {}
Generator(const Generator& other) = delete;
// get random state
GeneratorState* GetState();
GeneratorState GetState();
// set random state
void SetState(GeneratorState* state_in);
void SetState(const GeneratorState&);
// get current seed
uint64_t GetCurrentSeed();
// randomly generate a seed and return it
uint64_t Seed();
// set seed
void SetCurrentSeed(uint64_t seed);
// get cpu engine
std::mt19937_64& GetCPUEngine();
std::shared_ptr<std::mt19937_64> GetCPUEngine();
// set cpu engine
void SetCPUEngine(std::mt19937_64 engine);
void SetCPUEngine(std::shared_ptr<std::mt19937_64>);
uint64_t Random64();
bool is_init_py = false;
void SetIsInitPy(bool);
bool GetIsInitPy() const;
// CPU Generator singleton
static std::shared_ptr<Generator> GetInstance() {
if (NULL == gen_instance_) {
gen_instance_.reset(new paddle::framework::Generator());
}
return gen_instance_;
}
private:
GeneratorState state_;
std::shared_ptr<std::mt19937_64> engine_;
mutable std::mutex mu_;
// NOTE(zhiqiu): is_init_py_ is used to keep the generator compatible with the
// old per-op seed, and it should be removed after all random-related operators
// and unittests have been upgraded to use the generator.
bool is_init_py_ = false;
};
static std::shared_ptr<Generator> GetInstanceX() {
if (NULL == gen_instance_) {
gen_instance_.reset(new paddle::framework::Generator());
}
gen_instance_->is_init_py = true;
return gen_instance_;
}
// The DefaultCPUGenerator is used in manual_seed()
const std::shared_ptr<Generator>& DefaultCPUGenerator();
private:
static std::shared_ptr<Generator> gen_instance_;
std::shared_ptr<GeneratorState> state_;
mutable std::mutex mutex;
// If op seed is set or global is not set, the OpDefaultCPUEngine is used.
std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine();
Generator(const Generator& other, const std::lock_guard<std::mutex>&)
: state_(std::make_shared<GeneratorState>(*(other.state_))) {}
};
std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
} // namespace framework
} // namespace paddle
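As a usage note, a hedged sketch of the manual-seed flow these declarations support (the Python-side binding is an assumption; only the C++ calls below come from this header):

// Hypothetical illustration of what a manual_seed() binding would do.
#include "paddle/fluid/framework/generator.h"

void ManualSeed(uint64_t seed) {
  auto gen = paddle::framework::DefaultCPUGenerator();
  gen->SetCurrentSeed(seed);  // reseed the shared CPU engine
  gen->SetIsInitPy(true);     // ops that pass seed == 0 will now use this engine
}

After this call, framework::GetCPURandomEngine(0) returns the engine owned by DefaultCPUGenerator(), as implemented in generator.cc above.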
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h"
#include <cmath>
#include <functional>
#include <string>
#include <vector>
......@@ -74,12 +75,17 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto weights_shape = weights->dims();
auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
auto* weights_data = weights->mutable_data<float>(platform::CPUPlace());
EigenMatrixArrayMap weights_array_2d(
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],
weights_shape_2d[1]);
EigenMatrixArrayMap weights_array_2d(weights_data, weights_shape_2d[0],
weights_shape_2d[1]);
weights_array_2d.colwise() *= scale_array;
// Check for subnormal values that slow down convolution execution
for (int i = 0; i < weights->numel(); ++i) {
if (std::fpclassify(weights_data[i]) == FP_SUBNORMAL) weights_data[i] = 0;
}
}
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
......@@ -108,13 +114,6 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
GET_CONV_BN_NODES(conv_ac_pattern);
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *affine_channel);
if (fuse_option == DO_NOT_FUSE) {
VLOG(3) << "do not perform conv+affinechannel fuse";
return;
}
// Create eltwise_y (conv bias) variable
VarDesc eltwise_y_in_desc(
patterns::PDNodeName(name_scope_, "eltwise_y_in"));
......@@ -143,6 +142,7 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
desc.SetType("elementwise_add");
desc.SetAttr("axis", 1);
desc.SetAttr("use_mkldnn", conv->Op()->GetAttrIfExists<bool>("use_mkldnn"));
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
......
......@@ -36,7 +36,6 @@ namespace paddle {
namespace imperative {
void BasicEngine::Init(VarBase* var, bool retain_graph) {
sorted_sum_gradient_ = FLAGS_sort_sum_gradient;
retain_graph_ = retain_graph;
init_node_ = var->GradVarBase()->GradNode();
var->GradVarBase()->ClearGradNode();
......@@ -106,7 +105,7 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
auto& accumulator = accumulators_[var.get()];
if (!accumulator) {
if (sorted_sum_gradient_) {
if (FLAGS_sort_sum_gradient) {
accumulator.reset(new SortedGradientAccumulator(var.get()));
} else {
accumulator.reset(new EagerGradientAccumulator(var.get()));
......
......@@ -44,7 +44,6 @@ class BasicEngine : public Engine {
private:
std::shared_ptr<GradOpNode> init_node_;
bool sorted_sum_gradient_;
std::unordered_map<GradOpNode*, size_t> node_deps_;
std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
accumulators_;
......
......@@ -578,7 +578,6 @@ class PartialGradTask {
bool retain_graph_;
bool allow_unused_;
bool only_inputs_;
bool sorted_sum_gradient_{FLAGS_sort_sum_gradient};
};
PartialGradTask::PartialGradTask(
......@@ -981,7 +980,7 @@ void PartialGradTask::PrepareInitialGradientAccumulators(const OpBase *op) {
if (!accumulator) {
accumulator.reset(new GradientAccumulationInfo(
var, sorted_sum_gradient_, create_graph_));
var, FLAGS_sort_sum_gradient, create_graph_));
}
accumulator->IncreaseTotalRefCnt();
......
......@@ -15,7 +15,6 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
......@@ -103,8 +102,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
// params_file_ fields.
CP_MEMBER(opt_cache_dir_);
prog_file_ = std::move(other.prog_file_);
params_file_ = std::move(other.params_file_);
CP_MEMBER(prog_file_);
CP_MEMBER(params_file_);
CP_MEMBER(use_fc_padding_);
// GPU related.
......
......@@ -32,7 +32,6 @@
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/memory/memcpy.h"
......@@ -517,6 +516,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
AnalysisConfig, PaddleEngineKind::kAnalysis>(const AnalysisConfig &config) {
// TODO(NHZlX): Should add the link to the doc of
// paddle_infer::CreatePredictor<paddle_infer::Config>
if (config.glog_info_disabled()) {
FLAGS_logtostderr = 1;
FLAGS_minloglevel = 2; // GLOG_ERROR
......@@ -1059,3 +1060,122 @@ USE_TRT_CONVERTER(slice);
USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
#endif
namespace paddle_infer {
void Tensor::Reshape(const std::vector<int> &shape) { tensor_->Reshape(shape); }
std::vector<int> Tensor::shape() const { return tensor_->shape(); }
void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
return tensor_->SetLoD(x);
}
std::vector<std::vector<size_t>> Tensor::lod() const { return tensor_->lod(); }
const std::string &Tensor::name() const { return tensor_->name(); }
DataType Tensor::type() const { return tensor_->type(); }
Predictor::Predictor(const Config &config) {
const_cast<Config *>(&config)->SwitchUseFeedFetchOps(false);
// The second parameter indicates that the discard log is not printed
predictor_ = paddle::CreatePaddlePredictor<
Config, paddle::PaddleEngineKind::kAnalysis>(config);
}
std::vector<std::string> Predictor::GetInputNames() {
return predictor_->GetInputNames();
}
std::unique_ptr<Tensor> Predictor::GetInputHandle(const std::string &name) {
auto zero_copy_tensor = predictor_->GetInputTensor(name);
std::unique_ptr<Tensor> tensor(new Tensor(std::move(zero_copy_tensor)));
return tensor;
}
std::vector<std::string> Predictor::GetOutputNames() {
return predictor_->GetOutputNames();
}
std::unique_ptr<Tensor> Predictor::GetOutputHandle(const std::string &name) {
auto zero_copy_tensor = predictor_->GetOutputTensor(name);
std::unique_ptr<Tensor> tensor(new Tensor(std::move(zero_copy_tensor)));
return tensor;
}
bool Predictor::Run() { return predictor_->ZeroCopyRun(); }
std::unique_ptr<Predictor> Predictor::Clone() {
auto analysis_pred = predictor_->Clone();
std::unique_ptr<Predictor> pred(new Predictor(std::move(analysis_pred)));
return pred;
}
void Predictor::ClearIntermediateTensor() {
predictor_->ClearIntermediateTensor();
}
int GetNumBytesOfDataType(DataType dtype) {
switch (dtype) {
case DataType::FLOAT32:
return sizeof(float);
case DataType::INT64:
return sizeof(int64_t);
case DataType::INT32:
return sizeof(int32_t);
case DataType::UINT8:
return sizeof(uint8_t);
default:
assert(false);
return -1;
}
}
std::string GetVersion() { return paddle::get_version(); }
std::string UpdateDllFlag(const char *name, const char *value) {
return paddle::UpdateDllFlag(name, value);
}
} // namespace paddle_infer
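As a small illustration (a hypothetical helper, not part of this diff), GetNumBytesOfDataType combines with Tensor::shape() and Tensor::type() to size a host buffer:

// Hypothetical helper: byte size of a paddle_infer::Tensor's payload.
#include <functional>
#include <numeric>

size_t TensorByteSize(const paddle_infer::Tensor& t) {
  auto dims = t.shape();
  size_t numel = std::accumulate(dims.begin(), dims.end(), size_t{1},
                                 std::multiplies<size_t>());
  return numel *
         static_cast<size_t>(paddle_infer::GetNumBytesOfDataType(t.type()));
}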
namespace paddle_infer {
std::shared_ptr<Predictor> CreatePredictor(const Config &config) { // NOLINT
std::shared_ptr<Predictor> predictor(new Predictor(config));
return predictor;
}
namespace services {
PredictorPool::PredictorPool(const Config &config, size_t size) {
PADDLE_ENFORCE_GE(
size, 1UL,
paddle::platform::errors::InvalidArgument(
"The predictor pool size should be greater than 1, but it's (%d)",
size));
Config copy_config(config);
main_pred_.reset(new Predictor(config));
for (size_t i = 0; i < size - 1; i++) {
if (config.tensorrt_engine_enabled()) {
Config config_tmp(copy_config);
preds_.push_back(
std::move(std::unique_ptr<Predictor>(new Predictor(config_tmp))));
} else {
preds_.push_back(std::move(main_pred_->Clone()));
}
}
}
Predictor *PredictorPool::Retrive(size_t idx) {
PADDLE_ENFORCE_LT(
idx, preds_.size() + 1,
paddle::platform::errors::InvalidArgument(
"There are (%d) predictors in the pool, but the idx is (%d)", idx,
preds_.size() + 1));
if (idx == 0) {
return main_pred_.get();
}
return preds_[idx - 1].get();
}
} // namespace services
} // namespace paddle_infer
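A minimal end-to-end sketch of the paddle_infer API implemented above (model paths, shapes, the RunOnce wrapper, and the include path are placeholders or assumptions, not taken from this commit):

// Hypothetical usage of the new paddle_infer API.
#include <functional>
#include <numeric>
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

void RunOnce() {
  paddle_infer::Config config;
  config.SetModel("model_dir/model", "model_dir/params");
  auto predictor = paddle_infer::CreatePredictor(config);

  auto input_names = predictor->GetInputNames();
  auto input = predictor->GetInputHandle(input_names[0]);
  std::vector<float> in_data(1 * 3 * 318 * 318, 1.0f);
  input->Reshape({1, 3, 318, 318});
  input->CopyFromCpu(in_data.data());

  predictor->Run();

  auto output_names = predictor->GetOutputNames();
  auto output = predictor->GetOutputHandle(output_names[0]);
  auto out_shape = output->shape();
  int out_num = std::accumulate(out_shape.begin(), out_shape.end(), 1,
                                std::multiplies<int>());
  std::vector<float> out_data(out_num);
  output->CopyToCpu(out_data.data());
}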
......@@ -112,6 +112,12 @@ void PaddleBuf::Free() {
}
}
NativeConfig::NativeConfig() {
LOG(WARNING) << "The paddle::NativeConfig interface is going to be "
"deprecated in the next release, plase use the latest "
"paddle_infer::Config instead.";
}
std::string get_version() {
std::stringstream ss;
ss << "version: " << framework::paddle_version() << "\n";
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
......@@ -25,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -311,6 +313,8 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
NativeConfig, PaddleEngineKind::kNative>(const NativeConfig &config) {
// TODO(NHZlX): Should add the link to the doc of
// paddle_infer::CreatePredictor<paddle_infer::Config>
VLOG(3) << "create NativePaddlePredictor";
if (config.use_gpu) {
// 1. GPU memory
......
......@@ -347,6 +347,7 @@ class PD_INFER_DECL PaddlePredictor {
/// place of inference, etc.)
///
struct PD_INFER_DECL NativeConfig : public PaddlePredictor::Config {
NativeConfig();
/// GPU related fields.
bool use_gpu{false};
int device{0};
......@@ -421,7 +422,8 @@ enum class PaddleEngineKind {
};
template <typename ConfigT, PaddleEngineKind engine>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT& config);
PD_INFER_DECL std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(
const ConfigT& config);
template <>
PD_INFER_DECL std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
......@@ -437,6 +439,4 @@ PD_INFER_DECL std::string get_version();
PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value);
PD_INFER_DECL std::shared_ptr<framework::Cipher> MakeCipher(
const std::string& config_file);
} // namespace paddle
......@@ -22,9 +22,124 @@ limitations under the License. */
#pragma once
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle_analysis_config.h" // NOLINT
#include "paddle_api.h" // NOLINT
namespace paddle_infer {
using DataType = paddle::PaddleDType;
using PlaceType = paddle::PaddlePlace;
using PrecisionType = paddle::AnalysisConfig::Precision;
using Config = paddle::AnalysisConfig;
class PD_INFER_DECL Tensor {
public:
// Can only be created by predictor->GetInputHandle(const std::string& name)
// or predictor->GetOutputHandle(const std::string& name)
Tensor() = delete;
explicit Tensor(std::unique_ptr<paddle::ZeroCopyTensor>&& tensor)
: tensor_(std::move(tensor)) {}
void Reshape(const std::vector<int>& shape);
template <typename T>
void CopyFromCpu(const T* data);
// should add the place
template <typename T>
T* mutable_data(PlaceType place);
template <typename T>
void CopyToCpu(T* data);
template <typename T>
T* data(PlaceType* place, int* size) const;
void SetLoD(const std::vector<std::vector<size_t>>& x);
std::vector<std::vector<size_t>> lod() const;
DataType type() const;
std::vector<int> shape() const;
const std::string& name() const;
private:
std::unique_ptr<paddle::ZeroCopyTensor> tensor_;
};
class PD_INFER_DECL Predictor {
public:
Predictor() = default;
~Predictor() {}
// Used by Clone()
explicit Predictor(std::unique_ptr<paddle::PaddlePredictor>&& pred)
: predictor_(std::move(pred)) {}
explicit Predictor(const Config& config);
std::vector<std::string> GetInputNames();
std::unique_ptr<Tensor> GetInputHandle(const std::string& name);
bool Run();
std::vector<std::string> GetOutputNames();
std::unique_ptr<Tensor> GetOutputHandle(const std::string& name);
std::unique_ptr<Predictor> Clone();
void ClearIntermediateTensor();
private:
std::unique_ptr<paddle::PaddlePredictor> predictor_;
};
PD_INFER_DECL std::shared_ptr<Predictor> CreatePredictor(
const Config& config); // NOLINT
PD_INFER_DECL int GetNumBytesOfDataType(DataType dtype);
PD_INFER_DECL std::string GetVersion();
PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value);
template <typename T>
void Tensor::CopyFromCpu(const T* data) {
tensor_->copy_from_cpu<T>(data);
}
template <typename T>
void Tensor::CopyToCpu(T* data) {
return tensor_->copy_to_cpu<T>(data);
}
template <typename T>
T* Tensor::mutable_data(PlaceType place) {
return tensor_->mutable_data<T>(place);
}
template <typename T>
T* Tensor::data(PlaceType* place, int* size) const {
return tensor_->data<T>(place, size);
}
} // namespace paddle_infer
namespace paddle_infer {
namespace services {
class PD_INFER_DECL PredictorPool {
public:
PredictorPool() = delete;
PredictorPool(const PredictorPool&) = delete;
PredictorPool& operator=(const PredictorPool&) = delete;
explicit PredictorPool(const Config& config, size_t size = 1);
Predictor* Retrive(size_t idx);
private:
std::shared_ptr<Predictor> main_pred_;
std::vector<std::unique_ptr<Predictor>> preds_;
};
} // namespace services
} // namespace paddle_infer
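A brief sketch of how the PredictorPool declared above might serve multiple workers (pool size, model path, and the include path are placeholders or assumptions):

// Hypothetical usage: one predictor per worker, retrieved from the pool.
#include <string>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

void ServeWithPool(const std::string& model_dir) {
  paddle_infer::Config config;
  config.SetModel(model_dir);
  paddle_infer::services::PredictorPool pool(config, /*size=*/4);
  paddle_infer::Predictor* p0 = pool.Retrive(0);  // index 0 is the main predictor
  paddle_infer::Predictor* p1 = pool.Retrive(1);  // 1..size-1 are clones (or fresh TRT predictors)
  (void)p0;
  (void)p1;
}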
......@@ -185,12 +185,14 @@ void CpuPassStrategy::EnableMKLDNN() {
passes_.insert(passes_.begin(), "mkldnn_placement_pass");
for (auto &pass : std::vector<std::string>({
"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", //
"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv_transpose_bias_mkldnn_fuse_pass",
"conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
......
......@@ -515,3 +515,9 @@ if(WITH_MKLDNN)
inference_analysis_test(test_analyzer_capi_ner SRCS analyzer_capi_ner_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${CHINESE_NER_INSTALL_DIR}/model)
if(WITH_GPU)
inference_analysis_test(paddle_infer_api_test SRCS paddle_infer_api_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${RESNET50_MODEL_DIR})
endif()
......@@ -72,3 +72,59 @@ TEST(AnalysisPredictor, use_gpu) {
} // namespace inference
} // namespace paddle
namespace paddle_infer {
TEST(Predictor, use_gpu) {
std::string model_dir = FLAGS_infer_model + "/" + "model";
Config config;
config.EnableUseGpu(100, 0);
config.SetModel(model_dir + "/model", model_dir + "/params");
config.EnableLiteEngine(PrecisionType::kFloat32);
auto predictor = CreatePredictor(config);
const int batch = 1;
const int channel = 3;
const int height = 318;
const int width = 318;
const int input_num = batch * channel * height * width;
std::vector<float> input(input_num, 1);
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputHandle(input_names[0]);
input_t->Reshape({1, 3, 318, 318});
input_t->CopyFromCpu(input.data());
predictor->Run();
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_t->shape();
size_t out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
std::vector<float> out_data;
out_data.resize(out_num);
output_t->CopyToCpu(out_data.data());
const std::vector<float> truth_values = {
127.780396f, 738.16656f, 1013.2264f, -438.17206f, 366.4022f,
927.66187f, 736.2241f, -633.68567f, -329.92737f, -430.15637f,
-633.0639f, -146.54858f, -1324.2804f, -1349.3661f, -242.67671f,
117.44864f, -801.7251f, -391.51495f, -404.8202f, 454.16132f,
515.48206f, -133.03114f, 69.293076f, 590.09753f, -1434.6917f,
-1070.8903f, 307.0744f, 400.52573f, -316.12177f, -587.1265f,
-161.05742f, 800.3663f, -96.47157f, 748.708f, 868.17645f,
-447.9403f, 112.73656f, 1127.1992f, 47.43518f, 677.7219f,
593.1881f, -336.4011f, 551.3634f, 397.82474f, 78.39835f,
-715.4006f, 405.96988f, 404.25684f, 246.01978f, -8.430191f,
131.36617f, -648.0528f};
float* data_o = out_data.data();
for (size_t j = 0; j < out_num; j += 10) {
EXPECT_NEAR((data_o[j] - truth_values[j / 10]) / truth_values[j / 10], 0.,
10e-5);
}
}
} // namespace paddle_infer
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_runtime.h>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstring>
#include <numeric>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle_infer {
TEST(Predictor, use_gpu) {
LOG(INFO) << GetVersion();
UpdateDllFlag("conv_workspace_size_limit", "4000");
std::string model_dir = FLAGS_infer_model + "/model";
Config config;
config.SetModel(model_dir + "/model", model_dir + "/params");
config.EnableUseGpu(100, 0);
auto predictor = CreatePredictor(config);
auto pred_clone = predictor->Clone();
std::vector<int> in_shape = {1, 3, 318, 318};
int in_num = std::accumulate(in_shape.begin(), in_shape.end(), 1,
[](int &a, int &b) { return a * b; });
std::vector<float> input(in_num, 0);
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputHandle(input_names[0]);
input_t->Reshape(in_shape);
input_t->CopyFromCpu(input.data());
predictor->Run();
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
std::vector<float> out_data;
out_data.resize(out_num);
output_t->CopyToCpu(out_data.data());
predictor->ClearIntermediateTensor();
}
TEST(PredictorPool, basic) {
LOG(INFO) << GetVersion();
UpdateDllFlag("conv_workspace_size_limit", "4000");
std::string model_dir = FLAGS_infer_model + "/model";
Config config;
config.SetModel(model_dir + "/model", model_dir + "/params");
config.EnableUseGpu(100, 0);
services::PredictorPool pred_pool(config, 4);
auto pred = pred_pool.Retrive(2);
std::vector<int> in_shape = {1, 3, 318, 318};
int in_num = std::accumulate(in_shape.begin(), in_shape.end(), 1,
[](int &a, int &b) { return a * b; });
std::vector<float> input(in_num, 0);
auto in_names = pred->GetInputNames();
auto input_t = pred->GetInputHandle(in_names[0]);
input_t->name();
input_t->Reshape(in_shape);
input_t->CopyFromCpu(input.data());
pred->Run();
auto out_names = pred->GetOutputNames();
auto output_t = pred->GetOutputHandle(out_names[0]);
auto out_type = output_t->type();
LOG(INFO) << GetNumBytesOfDataType(out_type);
if (out_type == DataType::FLOAT32) {
PlaceType place;
int size;
output_t->data<float>(&place, &size);
}
}
} // namespace paddle_infer
......@@ -41,7 +41,7 @@ TEST(AnalysisPredictor, use_gpu) {
SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
std::vector<PaddleTensor> outputs;
for (auto& input : inputs_all) {
for (auto &input : inputs_all) {
ASSERT_TRUE(predictor->Run(input, &outputs));
predictor->ClearIntermediateTensor();
}
......@@ -49,3 +49,27 @@ TEST(AnalysisPredictor, use_gpu) {
} // namespace inference
} // namespace paddle
namespace paddle_infer {
TEST(PredictorPool, use_gpu) {
std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
Config config;
config.EnableUseGpu(100, 0);
config.SetModel(model_dir);
config.EnableTensorRtEngine();
services::PredictorPool pred_pool(config, 1);
auto predictor = pred_pool.Retrive(0);
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputHandle(input_names[0]);
std::vector<int> in_shape = {1, 3, 224, 224};
int in_num = std::accumulate(in_shape.begin(), in_shape.end(), 1,
[](int &a, int &b) { return a * b; });
std::vector<float> input(in_num, 0);
input_t->Reshape(in_shape);
input_t->CopyFromCpu(input.data());
predictor->Run();
}
} // namespace paddle_infer
......@@ -64,11 +64,11 @@ class BernoulliOpKernel<platform::CPUDeviceContext, T>
int64_t size = x->numel();
std::uniform_real_distribution<T> dist(0.0, 1.0);
auto gen_ptr = framework::Generator::GetInstance();
std::mt19937_64 &gen_engine = gen_ptr->GetCPUEngine();
auto gen_ptr = framework::DefaultCPUGenerator();
auto engine = gen_ptr->GetCPUEngine();
for (int64_t i = 0; i < size; ++i) {
out_data[i] = BernoulliFunctor(in_data[i], dist(gen_engine));
out_data[i] = BernoulliFunctor(in_data[i], dist(*engine));
}
}
}; // namespace operators
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#ifdef PADDLE_WITH_MKLDNN
......@@ -567,3 +568,14 @@ REGISTER_OP_CPU_KERNEL(
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_VERSION(conv_transpose)
.AddCheckpoint(
R"ROC(
Upgrade conv_transpose, add a new attribute [output_padding].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
{}));
......@@ -14,20 +14,19 @@
#pragma once
#include <ThreadPool.h>
#include <gflags/gflags.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/rw_lock.h"
......@@ -89,26 +88,17 @@ class UniformInitializer : public Initializer {
min_ = std::stof(attrs[2]);
max_ = std::stof(attrs[3]);
if (seed_ == 0) {
seed_ = std::random_device()();
}
random_engine_.seed(seed_);
dist_ = std::uniform_real_distribution<float>(min_, max_);
random_engine_ = framework::GetCPURandomEngine(seed_);
}
float GetValue() override {
return framework::Generator::GetInstance()->is_init_py
? dist_(framework::Generator::GetInstance()->GetCPUEngine())
: dist_(random_engine_);
// return dist_(random_engine_);
}
float GetValue() override { return dist_(*random_engine_); }
private:
float min_;
float max_;
std::minstd_rand random_engine_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::uniform_real_distribution<float> dist_;
};
......@@ -139,26 +129,18 @@ class GaussianInitializer : public Initializer {
mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]);
if (seed_ == 0) {
seed_ = std::random_device()();
}
random_engine_ = framework::GetCPURandomEngine(seed_);
random_engine_.seed(seed_);
dist_ = std::normal_distribution<float>(mean_, std_);
}
float GetValue() override {
return framework::Generator::GetInstance()->is_init_py
? dist_(framework::Generator::GetInstance()->GetCPUEngine())
: dist_(random_engine_);
// return dist_(random_engine_);
}
float GetValue() override { return dist_(*random_engine_); }
private:
float std_;
float mean_;
std::minstd_rand random_engine_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::normal_distribution<float> dist_;
};
......
......@@ -55,30 +55,22 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
std::memset(mask_data, 0, size * sizeof(*mask_data)); // NOLINT
return;
}
bool init_generator_py = framework::Generator::GetInstance()->is_init_py;
// std::minstd_rand engine;
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
std::random_device rnd;
std::minstd_rand engine;
int seed_data;
int seed_data = 0;
if (seed) {
seed_data = *(seed->data<int>());
} else {
seed_data =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : 0;
}
engine.seed(seed_data);
auto engine = framework::GetCPURandomEngine(seed_data);
std::uniform_real_distribution<float> dist(0, 1);
for (size_t i = 0; i < size; ++i) {
float cur_random =
init_generator_py
? dist(framework::Generator::GetInstance()->GetCPUEngine())
: dist(engine);
if (cur_random < dropout_prob) {
if (dist(*engine) < dropout_prob) {
mask_data[i] = 0;
y_data[i] = 0;
} else {
......
......@@ -39,26 +39,14 @@ class CPUGaussianRandomKernel : public framework::OpKernel<T> {
tensor->Resize(shape);
int64_t size = tensor->numel();
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
auto engine = framework::GetCPURandomEngine(seed);
if (framework::Generator::GetInstance()->is_init_py) {
std::mt19937_64& gen_engine =
framework::Generator::GetInstance()->GetCPUEngine();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(gen_engine);
}
} else {
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
}
};
}; // namespace operators
template <typename T>
class CPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
......
......@@ -13,11 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sampler.h"
#include <glog/logging.h>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/generator.h"
namespace paddle {
......@@ -28,22 +31,17 @@ Sampler::~Sampler() {}
UniformSampler::UniformSampler(int64_t range, unsigned int seed)
: Sampler(range, seed), inv_range_(1.0 / (range + 1)) {
random_engine_ = std::make_shared<std::mt19937_64>(seed_);
random_engine_ = framework::GetCPURandomEngine(seed_);
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
}
int64_t UniformSampler::Sample() const {
return framework::Generator::GetInstance()->is_init_py
? (*dist_)(framework::Generator::GetInstance()->GetCPUEngine())
: (*dist_)(*random_engine_);
// return (*dist_)(*random_engine_);
}
int64_t UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
float UniformSampler::Probability(int64_t value) const { return inv_range_; }
LogUniformSampler::LogUniformSampler(int64_t range, unsigned int seed)
: Sampler(range, seed), log_range_(log(range + 1)) {
random_engine_ = std::make_shared<std::mt19937_64>(seed_);
random_engine_ = framework::GetCPURandomEngine(seed_);
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
}
......@@ -52,10 +50,7 @@ int64_t LogUniformSampler::Sample() const {
// inverse_transform_sampling method
// More details:
// https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/
auto cur_random =
framework::Generator::GetInstance()->is_init_py
? (*dist_)(framework::Generator::GetInstance()->GetCPUEngine())
: (*dist_)(*random_engine_);
auto cur_random = (*dist_)(*random_engine_);
const int64_t value = static_cast<int64_t>(exp(cur_random * log_range_)) - 1;
// Mathematically, value should be <= range_, but might not be due to some
// floating point roundoff, so we mod by range_.
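For reference, a short note on the inverse-transform step above, using the definitions in this file (log_range_ = log(range + 1), u the uniform draw in [0, 1)):

value = \lfloor e^{\,u \cdot \log(\mathrm{range}+1)} \rfloor - 1, \qquad
P(\mathrm{value}=k) = \frac{\log\big((k+2)/(k+1)\big)}{\log(\mathrm{range}+1)}, \quad 0 \le k \le \mathrm{range},

so smaller ids are sampled with higher probability, which is the intended log-uniform behavior.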
......@@ -74,7 +69,7 @@ CustomSampler::CustomSampler(int64_t range, const float *probabilities,
const int *alias, const float *alias_probabilities,
unsigned int seed)
: Sampler(range, seed) {
random_engine_ = std::make_shared<std::mt19937>(seed_);
random_engine_ = framework::GetCPURandomEngine(seed_);
real_dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
int_dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
......@@ -84,14 +79,8 @@ CustomSampler::CustomSampler(int64_t range, const float *probabilities,
}
int64_t CustomSampler::Sample() const {
auto index =
framework::Generator::GetInstance()->is_init_py
? (*int_dist_)(framework::Generator::GetInstance()->GetCPUEngine())
: (*int_dist_)(*random_engine_);
auto p =
framework::Generator::GetInstance()->is_init_py
? (*real_dist_)(framework::Generator::GetInstance()->GetCPUEngine())
: (*real_dist_)(*random_engine_);
auto index = (*int_dist_)(*random_engine_);
auto p = (*real_dist_)(*random_engine_);
if (p > alias_probs_[index]) {
int alias = alias_[index];
......
......@@ -26,8 +26,8 @@ namespace math {
// TODO(wanghaoshuang): Support for GPU
/**
* Sample integers from [0, range).
*/
* Sample integers from [0, range).
*/
class Sampler {
public:
explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) {
......@@ -117,7 +117,7 @@ class CustomSampler : public Sampler {
const int* alias_;
const float* probs_;
const int exceptional_val = -1;
std::shared_ptr<std::mt19937> random_engine_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::shared_ptr<std::uniform_real_distribution<>> real_dist_;
std::shared_ptr<std::uniform_int_distribution<>> int_dist_;
};
......
......@@ -44,6 +44,7 @@ class FCPrimitiveFactory {
void ExecuteFcPrimitive(const LoDTensor* input, const Tensor* weights,
const Tensor* bias, LoDTensor* output,
const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx) {
RecomputeOutputDims(ctx, input, weights, output);
// If primitive has already been created and cached, don't create new one,
......@@ -74,8 +75,8 @@ class FCPrimitiveFactory {
"input format is equal to ncw."));
}
// Transform weights to default MKL-DNN format
weights_ = TransposeWeights(weights);
weights_ = CreateWeightsMemory(weights);
// Since MKL-DNN has a lot of limitations on what the input/weights/output
// dimensions should be, to simplify the code, the creation of primitive
// descriptor has been divided into separate cases, based on the number
......@@ -112,10 +113,13 @@ class FCPrimitiveFactory {
// Quantize weights and reorder to format chosen by FC primitive descriptor.
QuantizeWeights(ctx, fc_prim_desc->weights_desc());
bias_ = CreateMemory<float>(fc_prim_desc->bias_desc(), bias);
bias_ = CreateMemoryToBeCached<float>(fc_prim_desc->bias_desc(), bias);
// If int8 is desired, quantize bias into 32-bit signed int
QuantizeBias(*fc_prim_desc, ctx);
// Store weights and bias in the mkldnn cache
CacheWeightsAndBias(dev_ctx, ctx);
// Based on format determined by inner_product, create output in desired
// memory format
output_ = CreateDstMemory(*fc_prim_desc, ctx, output);
......@@ -262,14 +266,15 @@ class FCPrimitiveFactory {
}
// Convert data from one data format to another
mkldnn::memory Reorder(const memory::desc& src_desc,
const memory::desc& dst_desc, void* src_data) {
std::shared_ptr<mkldnn::memory> Reorder(const memory::desc& src_desc,
const memory::desc& dst_desc,
void* src_data) {
auto src_mem = memory(src_desc, engine_, src_data);
auto dst_mem = memory(dst_desc, engine_);
auto dst_mem = std::make_shared<memory>(dst_desc, engine_);
auto reorder = mkldnn::reorder(src_mem, dst_mem);
auto reorder = mkldnn::reorder(src_mem, *dst_mem);
mkldnn::stream astream(engine_);
reorder.execute(astream, src_mem, dst_mem);
reorder.execute(astream, src_mem, *dst_mem);
astream.wait();
return dst_mem;
......@@ -277,9 +282,10 @@ class FCPrimitiveFactory {
// Convert data from one data format to another and rescale it.
// If the desired data type is (un)signed int8, quantization occurs here.
mkldnn::memory Reorder(const memory& src_mem, const memory::desc& dst_md,
const std::vector<float>& scale_data) {
mkldnn::memory dst_mem = mkldnn::memory(dst_md, engine_);
std::shared_ptr<mkldnn::memory> ReorderWithScale(
const std::shared_ptr<memory> src_mem, const memory::desc& dst_md,
const std::vector<float>& scale_data) {
auto dst_mem = std::make_shared<mkldnn::memory>(dst_md, engine_);
mkldnn::primitive_attr attributes;
// According to MKL-DNN's documentation, the mask determines along which
// dimensions the scale should be applied.
......@@ -289,11 +295,11 @@ class FCPrimitiveFactory {
// because we perform per-output-channel quantization
int mask = CreateMask(0, scale_data.size() > 1);
attributes.set_output_scales(mask, scale_data);
auto reorder = mkldnn::reorder(src_mem, dst_mem, attributes);
auto reorder = mkldnn::reorder(*src_mem, *dst_mem, attributes);
mkldnn::stream astream(engine_);
reorder.execute(astream,
{{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
{{MKLDNN_ARG_FROM, *src_mem}, {MKLDNN_ARG_TO, *dst_mem}});
astream.wait();
return dst_mem;
......@@ -323,16 +329,38 @@ class FCPrimitiveFactory {
return memory(desc, engine_, data);
}
// Transpose weights through MKL-DNN's reorder from io to oi format.
mkldnn::memory TransposeWeights(const Tensor* weights) {
template <typename T>
std::shared_ptr<mkldnn::memory> CreateMemoryToBeCached(
const mkldnn::memory::desc& desc, const Tensor* tensor) {
return CreateMemoryToBeCached(desc,
platform::to_void_cast<T>(tensor->data<T>()));
}
std::shared_ptr<mkldnn::memory> CreateMemoryToBeCached(
const mkldnn::memory::desc& desc, void* data) {
return std::make_shared<memory>(desc, engine_, data);
}
// Create weights memory and transform to default MKL-DNN format
std::shared_ptr<mkldnn::memory> CreateWeightsMemory(const Tensor* weights) {
auto dims = framework::vectorize(weights->dims());
std::swap(dims[0], dims[1]); // Correct output dimensions
auto src_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::io);
auto dst_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oi);
// Transpose weights through MKL-DNN's reorder from io to oi format.
return Reorder(src_desc, dst_desc,
platform::to_void_cast<float>(weights->data<float>()));
}
void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx) {
const std::string key = platform::CreateKey(platform::ThreadIDasStr());
const std::string weights_key = key + ctx.InputName("W");
const std::string bias_key = key + ctx.InputName("Bias");
dev_ctx.SetBlob(weights_key, weights_);
dev_ctx.SetBlob(bias_key, bias_);
}
// Compute the bias scales so that its values correspond to the
// scale of data being an output of weights and input multiplication
std::vector<float> ComputeBiasScales(const ExecutionContext& ctx) {
......@@ -388,14 +416,14 @@ class FCPrimitiveFactory {
}
void QuantizeWeights(const ExecutionContext& ctx, memory::desc dst) {
weights_ =
Reorder(*weights_, dst, ctx.Attr<std::vector<float>>("Scale_weights"));
weights_ = ReorderWithScale(weights_, dst,
ctx.Attr<std::vector<float>>("Scale_weights"));
}
void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc,
const ExecutionContext& ctx) {
auto bias_scales = ComputeBiasScales(ctx);
bias_ = Reorder(*bias_, fc_prim_desc.bias_desc(), bias_scales);
bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales);
}
// Fuse relu into FC with activation type attribute has been set to 'relu'
......@@ -463,10 +491,10 @@ class FCPrimitiveFactory {
private:
const mkldnn::engine& engine_;
boost::optional<memory> bias_;
boost::optional<memory> input_;
boost::optional<memory> output_;
boost::optional<memory> weights_;
std::shared_ptr<memory> bias_;
std::shared_ptr<memory> weights_;
boost::optional<inner_product_forward> fc_;
};
......@@ -476,19 +504,13 @@ class FCPrimitiveFactory {
template <typename T_in, typename T_w, typename T_out>
static std::shared_ptr<FCPrimitiveFactory<T_in, T_w, T_out>>
GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx, const Tensor* input,
const Tensor* weights,
const mkldnn::engine& mkldnn_engine) {
const std::string key = platform::CreateKey(
platform::ThreadIDasStr(), input->format(), input->dims()[0],
framework::vectorize<int>(weights->dims()), ctx.OutputName("Out"));
const std::string& key) {
auto prim_creator =
std::static_pointer_cast<FCPrimitiveFactory<T_in, T_w, T_out>>(
dev_ctx.GetBlob(key));
if (prim_creator == nullptr) {
prim_creator =
std::make_shared<FCPrimitiveFactory<T_in, T_w, T_out>>(mkldnn_engine);
prim_creator = std::make_shared<FCPrimitiveFactory<T_in, T_w, T_out>>(
dev_ctx.GetEngine());
dev_ctx.SetBlob(key, prim_creator);
}
......@@ -498,24 +520,24 @@ GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
// Choose appropriate primitive factory implementation based on inferred
// output type (uint8, int8 or float).
template <typename T_in, typename T_w>
static void ExecuteFc(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx, const LoDTensor* input,
static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input,
const Tensor* w, const Tensor* bias, LoDTensor* output,
const mkldnn::engine& mkldnn_engine, bool fuse_relu,
bool force_fp32_output) {
bool fuse_relu, bool force_fp32_output) {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const std::string prim_key = platform::CreateKey(
platform::ThreadIDasStr(), input->format(), input->dims()[0],
framework::vectorize<int>(w->dims()), ctx.OutputName("Out"));
constexpr bool is_int8 =
std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
if (!is_int8 || force_fp32_output) {
GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, ctx, input, w, mkldnn_engine)
->ExecuteFcPrimitive(input, w, bias, output, ctx);
GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
} else if (fuse_relu) {
GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, ctx, input, w,
mkldnn_engine)
->ExecuteFcPrimitive(input, w, bias, output, ctx);
GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
} else {
GetPrimitiveFactory<T_in, T_w, int8_t>(dev_ctx, ctx, input, w,
mkldnn_engine)
->ExecuteFcPrimitive(input, w, bias, output, ctx);
GetPrimitiveFactory<T_in, T_w, int8_t>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
}
}
......@@ -526,9 +548,6 @@ class FCMKLDNNOpKernel : public framework::OpKernel<T_in> {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("FC MKL-DNN must use CPUPlace."));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto input = ctx.Input<LoDTensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
......@@ -537,8 +556,8 @@ class FCMKLDNNOpKernel : public framework::OpKernel<T_in> {
bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
ExecuteFc<T_in, T_w>(dev_ctx, ctx, input, w, bias, output, mkldnn_engine,
fuse_relu, force_fp32_output);
ExecuteFc<T_in, T_w>(ctx, input, w, bias, output, fuse_relu,
force_fp32_output);
output->set_layout(DataLayout::kMKLDNN);
}
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/mean_op.h"
......@@ -35,23 +36,11 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
T* data = tensor->mutable_data<T>(context.GetPlace());
int64_t size = tensor->numel();
std::normal_distribution<T> dist(mean, std);
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
auto engine = framework::GetCPURandomEngine(seed);
if (framework::Generator::GetInstance()->is_init_py) {
std::mt19937_64& gen_engine =
framework::Generator::GetInstance()->GetCPUEngine();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(gen_engine);
}
} else {
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
tensor->set_layout(DataLayout::kMKLDNN);
......
......@@ -24,49 +24,69 @@ class AdadeltaOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AvgSquaredGrad"),
"Input(AvgSquaredGrad) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"),
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
platform::errors::InvalidArgument(
"Input(Param) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
platform::errors::InvalidArgument(
"Input(Grad) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("AvgSquaredGrad"), true,
platform::errors::InvalidArgument(
"Input(AvgSquaredGrad) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("AvgSquaredUpdate"), true,
platform::errors::InvalidArgument(
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
true,
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(),
ctx->GetInputsVarType("Param").front()));
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AvgSquaredGradOut"),
"Output(AvgSquaredGradOut) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AvgSquaredUpdateOut"),
"Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null.");
true,
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(),
ctx->GetInputsVarType("Grad").front()));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("ParamOut"), true,
platform::errors::InvalidArgument(
"Output(ParamOut) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("AvgSquaredGradOut"), true,
platform::errors::InvalidArgument(
"Output(AvgSquaredGradOut) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("AvgSquaredUpdateOut"), true,
platform::errors::InvalidArgument(
"Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null."));
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"param and grad input of AdadeltaOp should have same dimension");
PADDLE_ENFORCE_NE(framework::product(ctx->GetInputDim("AvgSquaredGrad")), 0,
"Maybe the Input variable AvgSquaredGrad has not "
"been initialized. You may need to confirm if you put "
"exe.run(startup_program) after optimizer.minimize "
"function.");
PADDLE_ENFORCE_NE(
framework::product(ctx->GetInputDim("AvgSquaredGrad")), 0,
platform::errors::InvalidArgument(
"Maybe the Input variable AvgSquaredGrad has not "
"been initialized. You may need to confirm if you put "
"exe.run(startup_program) after optimizer.minimize "
"function."));
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredGrad"),
"Param and AvgSquaredGrad input of AdadeltaOp "
"should have same dimension");
platform::errors::InvalidArgument(
"Param and AvgSquaredGrad input of AdadeltaOp "
"should have same dimension"));
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredUpdate"),
"Param and AvgSquaredUpdate input of AdadeltaOp "
"should have same dimension");
platform::errors::InvalidArgument(
"Param and AvgSquaredUpdate input of AdadeltaOp "
"should have same dimension"));
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
......
......@@ -24,17 +24,19 @@ class AdadeltaOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type()));
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type()));
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type())));
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto avg_squared_grad_out_tensor =
......
......@@ -46,22 +46,11 @@ class CPURandintKernel : public framework::OpKernel<T> {
std::uniform_int_distribution<T> dist(ctx.Attr<int>("low"),
ctx.Attr<int>("high") - 1);
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
auto engine = framework::GetCPURandomEngine(seed);
if (framework::Generator::GetInstance()->is_init_py) {
std::mt19937_64& gen_engine =
framework::Generator::GetInstance()->GetCPUEngine();
for (int64_t i = 0; i < size; ++i) data[i] = dist(gen_engine);
} else {
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
}
};
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <ctime>
#include <string>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
......@@ -29,20 +30,12 @@ namespace operators {
template <typename T>
static inline void random_permate(T* data_ptr, int num, unsigned int seed) {
auto engine = framework::GetCPURandomEngine(seed);
for (int i = 0; i < num; ++i) {
data_ptr[i] = static_cast<T>(i);
}
if (framework::Generator::GetInstance()->is_init_py) {
std::shuffle(data_ptr, data_ptr + num,
framework::Generator::GetInstance()->GetCPUEngine());
} else {
if (seed == 0) {
seed = std::random_device()();
}
std::srand(seed);
std::random_shuffle(data_ptr, data_ptr + num);
}
std::shuffle(data_ptr, data_ptr + num, *engine);
}
template <typename DeviceContext, typename T>
......
......@@ -51,20 +51,15 @@ class SamplingIdKernel : public framework::OpKernel<T> {
framework::TensorToVector(*input, context.device_context(), &ins_vector);
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_real_distribution<T> dist(
static_cast<T>(context.Attr<float>("min")),
static_cast<T>(context.Attr<float>("max")));
auto engine = framework::GetCPURandomEngine(seed);
std::vector<int64_t> ids(batch_size);
for (int i = 0; i < batch_size; ++i) {
T r = framework::Generator::GetInstance()->is_init_py
? dist(framework::Generator::GetInstance()->GetCPUEngine())
: dist(engine);
T r = dist(*engine);
int idx = width - 1;
for (int j = 0; j < width; ++j) {
if ((r -= ins_vector[i * width + j]) < 0) {
......
......@@ -23,22 +23,27 @@ class TopkOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of TopkOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of TopkOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Indices"),
"Output(Indices) of TopkOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of TopkOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of TopkOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Indices"), true,
platform::errors::InvalidArgument(
"Output(Indices) of TopkOp should not be null."));
auto input_dims = ctx->GetInputDim("X");
const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
PADDLE_ENFORCE_GE(input_dims.size(), 1, "input must have >= 1d shape");
PADDLE_ENFORCE_GE(input_dims.size(), 1, platform::errors::InvalidArgument(
"input must have >= 1d shape"));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(input_dims[input_dims.size() - 1], k,
"input must have >= k columns");
PADDLE_ENFORCE_GE(
input_dims[input_dims.size() - 1], k,
platform::errors::InvalidArgument("input must have >= k columns"));
}
framework::DDim dims = input_dims;
......
......@@ -43,8 +43,9 @@ template <typename DeviceContext, typename T>
class TopkOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <limits>
#include <random>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -167,22 +168,10 @@ class CPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
TruncatedNormal<T> truncated_normal(mean, std);
int64_t size = tensor->numel();
if (framework::Generator::GetInstance()->is_init_py) {
std::mt19937_64& gen_engine =
framework::Generator::GetInstance()->GetCPUEngine();
for (int64_t i = 0; i < size; ++i) {
data[i] = truncated_normal(dist(gen_engine));
}
} else {
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = truncated_normal(dist(engine));
}
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
auto engine = framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = truncated_normal(dist(*engine));
}
}
};
......
......@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/uniform_random_op.h"
#include <string>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
......@@ -62,34 +64,12 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max")));
auto gen_ptr = framework::Generator::GetInstance();
if (gen_ptr->is_init_py) {
std::mt19937_64 &gen_engine = gen_ptr->GetCPUEngine();
// auto gen_engine = gen_ptr_->GetCPUEngine();
// std::uniform_real_distribution<T> dist(
// static_cast<T>(ctx.Attr<float>("min")),
// static_cast<T>(ctx.Attr<float>("max")));
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
auto engine = framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(gen_engine);
}
} else {
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
// std::uniform_real_distribution<T> dist(
// static_cast<T>(ctx.Attr<float>("min")),
// static_cast<T>(ctx.Attr<float>("max")));
// int64_t size = tensor->numel();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
// std::mt19937_64 &engine = gen_ptr->GetCPUEngine();
// auto engine = gen_ptr_->GetCPUEngine();
unsigned int diag_num =
static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
......@@ -139,12 +119,12 @@ class UniformRandomOp : public framework::OperatorWithKernel {
if (ctx->HasInputs("ShapeTensorList")) {
// top priority shape
auto inputs_name = ctx->Inputs("ShapeTensorList");
PADDLE_ENFORCE_GT(
inputs_name.size(), 0,
platform::errors::InvalidArgument(
"Input(ShapeTensorList)'size of Op(uniform_random) can't be zero."
"Please check the Attr(shape)'s size of"
"Op(fluid.layers.uniform_random).)"));
PADDLE_ENFORCE_GT(inputs_name.size(), 0,
platform::errors::InvalidArgument(
"Input(ShapeTensorList)'size of "
"Op(uniform_random) can't be zero."
"Please check the Attr(shape)'s size of"
"Op(fluid.layers.uniform_random).)"));
auto out_dims = std::vector<int>(inputs_name.size(), -1);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
......@@ -88,15 +89,12 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
}
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
if (framework::Generator::GetInstance()->is_init_py) {
seed = static_cast<unsigned int>(
framework::Generator::GetInstance()->GetCurrentSeed());
} else {
if (seed == 0) {
std::random_device rd;
seed = rd();
}
if (seed == 0) {
std::random_device rd;
seed = rd();
}
T min = static_cast<T>(context.Attr<float>("min"));
T max = static_cast<T>(context.Attr<float>("max"));
unsigned int diag_num =
......
......@@ -29,23 +29,36 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindGenerator(py::module* m) {
py::class_<framework::GeneratorState>(*m, "GeneratorState", "");
py::class_<std::mt19937_64>(*m, "mt19937_64", "");
void BindGenerator(py::module* m_ptr) {
auto& m = *m_ptr;
py::class_<framework::GeneratorState,
std::shared_ptr<framework::GeneratorState>>(m, "GeneratorState")
.def("current_seed",
[](std::shared_ptr<framework::GeneratorState>& self) {
return self->current_seed;
});
py::class_<std::mt19937_64>(m, "mt19937_64", "");
py::class_<framework::Generator, std::shared_ptr<framework::Generator>>(
*m, "Generator")
.def(py::init([]() { return framework::Generator::GetInstanceX(); }),
py::return_value_policy::reference)
.def("get_state", &framework::Generator::GetState,
py::return_value_policy::move)
m, "Generator")
.def("__init__",
[](framework::Generator& self) {
new (&self) framework::Generator();
})
.def("get_state", &framework::Generator::GetState)
.def("set_state", &framework::Generator::SetState)
.def("manual_seed", &framework::Generator::SetCurrentSeed)
.def("manual_seed",
[](std::shared_ptr<framework::Generator>& self, uint64_t seed) {
self->SetCurrentSeed(seed);
return self;
})
.def("seed", &framework::Generator::Seed)
.def("initial_seed", &framework::Generator::GetCurrentSeed)
.def("random", &framework::Generator::Random64)
.def("get_cpu_engine", &framework::Generator::GetCPUEngine,
py::return_value_policy::move)
.def("set_cpu_engine", &framework::Generator::SetCPUEngine);
// .def("get_cpu_engine", &framework::Generator::GetCPUEngine)
// .def("set_cpu_engine", &framework::Generator::SetCPUEngine)
.def_property("_is_init_py", &framework::Generator::GetIsInitPy,
&framework::Generator::SetIsInitPy);
m.def("default_cpu_generator", &framework::DefaultCPUGenerator);
} // end Generator
} // end namespace pybind
} // end namespace paddle
} // namespace paddle
......@@ -206,9 +206,9 @@ void BindInferenceApi(py::module *m) {
BindMkldnnQuantizerConfig(m);
#endif
m->def("create_paddle_predictor",
&paddle::CreatePaddlePredictor<AnalysisConfig>);
&paddle::CreatePaddlePredictor<AnalysisConfig>, py::arg("config"));
m->def("create_paddle_predictor",
&paddle::CreatePaddlePredictor<NativeConfig>);
&paddle::CreatePaddlePredictor<NativeConfig>, py::arg("config"));
m->def("paddle_dtype_size", &paddle::PaddleDtypeSize);
m->def("paddle_tensor_to_bytes", &SerializePDTensorToBytes);
}
......
......@@ -125,8 +125,15 @@ echo ========================================
echo Step 1. Cmake ...
echo ========================================
echo cmake .. -G "Visual Studio 14 2015 Win64" -DWITH_AVX=%WITH_AVX% -DWITH_GPU=%WITH_GPU% -DWITH_MKL=%WITH_MKL% -DWITH_TESTING=%WITH_TESTING% -DWITH_PYTHON=%WITH_PYTHON% -DCUDA_TOOLKIT_ROOT_DIR=%CUDA_TOOLKIT_ROOT_DIR% -DON_INFER=%ON_INFER% -DTHIRD_PARTY_PATH=%THIRD_PARTY_PATH%
cmake .. -G "Visual Studio 14 2015 Win64" -DWITH_AVX=%WITH_AVX% -DWITH_GPU=%WITH_GPU% -DWITH_MKL=%WITH_MKL% -DWITH_TESTING=%WITH_TESTING% -DWITH_PYTHON=%WITH_PYTHON% -DCUDA_TOOLKIT_ROOT_DIR=%CUDA_TOOLKIT_ROOT_DIR% -DON_INFER=%ON_INFER% -DTHIRD_PARTY_PATH=%THIRD_PARTY_PATH%
echo cmake .. -G "Visual Studio 14 2015 Win64" -DWITH_AVX=%WITH_AVX% -DWITH_GPU=%WITH_GPU% -DWITH_MKL=%WITH_MKL% ^
-DWITH_TESTING=%WITH_TESTING% -DWITH_PYTHON=%WITH_PYTHON% -DCUDA_TOOLKIT_ROOT_DIR=%CUDA_TOOLKIT_ROOT_DIR% ^
-DON_INFER=%ON_INFER% -DWITH_INFERENCE_API_TEST=%WITH_INFERENCE_API_TEST% -DTHIRD_PARTY_PATH=%THIRD_PARTY_PATH% ^
-DINFERENCE_DEMO_INSTALL_DIR=%INFERENCE_DEMO_INSTALL_DIR%
cmake .. -G "Visual Studio 14 2015 Win64" -DWITH_AVX=%WITH_AVX% -DWITH_GPU=%WITH_GPU% -DWITH_MKL=%WITH_MKL% ^
-DWITH_TESTING=%WITH_TESTING% -DWITH_PYTHON=%WITH_PYTHON% -DCUDA_TOOLKIT_ROOT_DIR=%CUDA_TOOLKIT_ROOT_DIR% ^
-DON_INFER=%ON_INFER% -DWITH_INFERENCE_API_TEST=%WITH_INFERENCE_API_TEST% -DTHIRD_PARTY_PATH=%THIRD_PARTY_PATH% ^
-DINFERENCE_DEMO_INSTALL_DIR=%INFERENCE_DEMO_INSTALL_DIR%
goto:eof
:cmake_error
......@@ -276,7 +283,10 @@ echo git fetch upstream $BRANCH # develop is not fetched>> check_change_of_
echo fi>> check_change_of_unittest.sh
echo git checkout -b origin_pr >> check_change_of_unittest.sh
echo git checkout -f $BRANCH >> check_change_of_unittest.sh
echo cmake .. -G "Visual Studio 14 2015 Win64" -DWITH_AVX=%WITH_AVX% -DWITH_GPU=%WITH_GPU% -DWITH_MKL=%WITH_MKL% -DWITH_TESTING=%WITH_TESTING% -DWITH_PYTHON=%WITH_PYTHON% -DCUDA_TOOLKIT_ROOT_DIR=%CUDA_TOOLKIT_ROOT_DIR% -DON_INFER=%ON_INFER% -DTHIRD_PARTY_PATH=%THIRD_PARTY_PATH% >> check_change_of_unittest.sh
echo cmake .. -G "Visual Studio 14 2015 Win64" -DWITH_AVX=%WITH_AVX% -DWITH_GPU=%WITH_GPU% -DWITH_MKL=%WITH_MKL% ^
-DWITH_TESTING=%WITH_TESTING% -DWITH_PYTHON=%WITH_PYTHON% -DCUDA_TOOLKIT_ROOT_DIR=%CUDA_TOOLKIT_ROOT_DIR% ^
-DON_INFER=%ON_INFER% -DWITH_INFERENCE_API_TEST=%WITH_INFERENCE_API_TEST% -DTHIRD_PARTY_PATH=%THIRD_PARTY_PATH% ^
-DINFERENCE_DEMO_INSTALL_DIR=%INFERENCE_DEMO_INSTALL_DIR% >> check_change_of_unittest.sh
echo cat ^<^<EOF>> check_change_of_unittest.sh
echo ============================================ >> check_change_of_unittest.sh
echo Generate unit tests.spec of develop. >> check_change_of_unittest.sh
......
......@@ -1399,6 +1399,9 @@ function main() {
local CMD=$1
local parallel_number=$2
init
if [ "$CMD" != "assert_file_approvals" ];then
python ${PADDLE_ROOT}/tools/summary_env.py
fi
case $CMD in
build_only)
cmake_gen_and_build ${PYTHON_ABI:-""} ${parallel_number}
......
......@@ -230,8 +230,6 @@ from .framework import grad #DEFINE_ALIAS
from .framework import no_grad #DEFINE_ALIAS
from .framework import save #DEFINE_ALIAS
from .framework import load #DEFINE_ALIAS
from .framework import prepare_context #DEFINE_ALIAS
from .framework import ParallelEnv #DEFINE_ALIAS
from .framework import DataParallel #DEFINE_ALIAS
from .framework import NoamDecay #DEFINE_ALIAS
......
......@@ -42,9 +42,11 @@ class TestSentimentMethods(unittest.TestCase):
def test_data_set(self):
data_set = st.load_sentiment_data()
last_label = -1
for each in st.test():
self.assertNotEqual(each[1], last_label)
last_label = each[1]
self.assertEqual(len(data_set), st.NUM_TOTAL_INSTANCES)
self.assertEqual(len(list(st.train())), st.NUM_TRAINING_INSTANCES)
self.assertEqual(
......
......@@ -12,4 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import spawn
from .spawn import spawn
from . import parallel
from .parallel import init_parallel_env
from .parallel import get_rank
from .parallel import get_world_size
from paddle.fluid.dygraph.parallel import prepare_context #DEFINE_ALIAS
from paddle.fluid.dygraph.parallel import ParallelEnv #DEFINE_ALIAS
from . import collective
from .collective import *
# start multiprocess apis
__all__ = ["spawn"]
# dygraph parallel apis
__all__ += [
"init_parallel_env",
"get_rank",
"get_world_size",
"prepare_context",
"ParallelEnv",
]
# collective apis
__all__ += collective.__all__
......@@ -18,16 +18,15 @@ from .base.distributed_strategy import DistributedStrategy
from .base.fleet_base import Fleet
from .base.util_factory import UtilBase
from .dataset import *
#from . import metrics
__all__ = [
"DistributedStrategy",
"UtilBase",
"DatasetFactory",
"DatasetBase",
"InMemoryDataset",
"QueueDataset",
"UserDefinedRoleMaker",
"PaddleCloudRoleMaker",
"Fleet",
]
fleet = Fleet()
......
......@@ -17,6 +17,8 @@ from paddle.distributed.fleet.proto import distributed_strategy_pb2
from paddle.fluid.framework import Variable, set_flags, core
import google.protobuf.text_format
__all__ = ["DistributedStrategy"]
def get_msg_dict(msg):
res_dict = {}
......
......@@ -22,7 +22,7 @@ from .runtime_factory import RuntimeFactory
from .util_factory import UtilFactory
from paddle.fluid.wrapped_decorator import wrap_decorator
__all__ = ['Fleet']
#__all__ = ['Fleet']
def _inited_runtime_handler_(func):
......
......@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["MetaOptimizerFactory"]
from ..meta_optimizers import *
meta_optimizer_names = list(
......
......@@ -17,7 +17,7 @@ import numpy as np
from multiprocessing import Process, Manager
import paddle.fluid as fluid
__all__ = ['RoleMakerBase', 'UserDefinedRoleMaker', 'PaddleCloudRoleMaker']
#__all__ = ['UserDefinedRoleMaker', 'PaddleCloudRoleMaker']
class Role:
......
......@@ -22,17 +22,3 @@ from .lars_optimizer import LarsOptimizer
from .async_graph_execution_optimizer import AsyncGraphExecutionOptimizer
from .dgc_optimizer import DGCOptimizer
from .lamb_optimizer import LambOptimizer
__all__ = [
'AMPOptimizer',
'RecomputeOptimizer',
'GradientMergeOptimizer',
'AsyncMetaOptimizer',
'GraphExecutionOptimizer',
'PipelineOptimizer',
'LocalSGDOptimizer',
'LarsOptimizer',
'AsyncGraphExecutionOptimizer',
'DGCOptimizer',
'LambOptimizer',
]
......@@ -14,8 +14,6 @@
import paddle.fluid.contrib.mixed_precision as mixed_precision
from .meta_optimizer_base import MetaOptimizerBase
__all__ = ["AMPOptimizer"]
class AMPOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
......
......@@ -15,8 +15,6 @@ from paddle.fluid.optimizer import Momentum, DGCMomentumOptimizer
from .meta_optimizer_base import MetaOptimizerBase
import logging
__all__ = ["DGCOptimizer"]
class DGCOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
......
......@@ -14,10 +14,6 @@
from paddle.fluid.optimizer import GradientMergeOptimizer as GM
from .meta_optimizer_base import MetaOptimizerBase
__all__ = ["GradientMergeOptimizer"]
# amp + gradient merge + lamb
class GradientMergeOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
......
......@@ -16,8 +16,6 @@ from paddle.fluid.optimizer import LambOptimizer as LAMB
from .meta_optimizer_base import MetaOptimizerBase
import logging
__all__ = ["LambOptimizer"]
class LambOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
......
......@@ -15,8 +15,6 @@ from paddle.fluid.optimizer import Momentum, LarsMomentumOptimizer
from .meta_optimizer_base import MetaOptimizerBase
import logging
__all__ = ["LarsOptimizer"]
class LarsOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
......
......@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["MetaOptimizerBase"]
from paddle.fluid.optimizer import Optimizer
......
......@@ -20,8 +20,6 @@ from paddle.fluid.optimizer import PipelineOptimizer as PO
from .meta_optimizer_base import MetaOptimizerBase
from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
__all__ = ["PipelineOptimizer"]
class PipelineHelper(CollectiveHelper):
def __init__(self, role_maker, nrings=1, wait_port='6174'):
......
......@@ -14,8 +14,6 @@
from paddle.fluid.optimizer import RecomputeOptimizer as RO
from .meta_optimizer_base import MetaOptimizerBase
__all__ = ["RecomputeOptimizer"]
class RecomputeOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
......
......@@ -11,3 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .metric import *
__all__ = [
"sum",
"max",
"min",
"auc",
"mae",
"rmse",
"mse",
"acc",
]
......@@ -14,5 +14,3 @@
from .collective_runtime import CollectiveRuntime
from .parameter_server_runtime import ParameterServerRuntime
__all__ = ["CollectiveRuntime," "ParameterServerRuntime", ]
......@@ -15,4 +15,4 @@
from .fs import *
from .http_server import KVHandler, KVHTTPServer, KVServer
__all__ = ['KVHandler', 'KVHTTPServer', 'KVServer'] + fs.__all__
#__all__ = ['KVHandler', 'KVHTTPServer', 'KVServer'] + fs.__all__
......@@ -44,11 +44,9 @@ import time
import six
import copy
from argparse import ArgumentParser, REMAINDER
import paddle
import paddle.fluid as fluid
from paddle.distributed.utils import *
import paddle.distributed.cloud_utils as cloud_utils
from paddle.distributed import cloud_utils
def _print_arguments(args):
......@@ -167,7 +165,8 @@ def get_cluster_from_args(args, selected_gpus):
def get_gpus(selected_gpus):
if selected_gpus is None:
gpus_num = fluid.core.get_cuda_device_count()
from paddle.fluid import core
gpus_num = core.get_cuda_device_count()
selected_gpus = [str(x) for x in range(0, gpus_num)]
else:
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
......@@ -190,7 +189,7 @@ def get_gpus(selected_gpus):
return selected_gpus
def launch(args):
def get_cluster_and_pod(args):
# parse arguments, used for cloud-single-machine and local
selected_gpus = get_gpus(args.selected_gpus)
trainers_num = cloud_utils.get_trainers_num()
......@@ -209,6 +208,12 @@ def launch(args):
cluster, pod = get_cluster_from_args(args, selected_gpus)
logger.info("get cluster from args:{}".format(cluster))
return cluster, pod
def launch(args):
cluster, pod = get_cluster_and_pod(args)
procs = start_local_trainers(
cluster,
pod,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import warnings
from paddle import compat as cpt
# deprecated module import
from paddle.fluid import core
from paddle.fluid.framework import _set_expected_place
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph.parallel import ParallelEnv
__all__ = ["init_parallel_env"]
ParallelStrategy = core.ParallelStrategy
def init_parallel_env(backend='nccl'):
"""
Initialize parallel training environments in dynamic mode.
Args:
backend(str, optional): The backend to communication between multiple devices.
Now only support ``nccl`` . Default value is ``nccl`` .
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env()
# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
dist.spawn(train)
"""
# 1. input check
if not isinstance(backend, six.string_types):
raise TypeError("input `backend` type error, expected type is str, "
"but received type is %s." % type(backend))
if cpt.to_text(backend) != 'nccl':
raise ValueError(
"backend `%s` is not supported, now only supports `nccl` backend." %
backend)
# 2. check env
def _check_var_exists(var_name):
var = os.environ.get(var_name, None)
if var is None:
raise ValueError("paddle.distributed initialize error, "
"environment variable %s is needed, but not set." %
var_name)
_check_var_exists("FLAGS_selected_gpus")
_check_var_exists("PADDLE_TRAINER_ID")
_check_var_exists("PADDLE_CURRENT_ENDPOINT")
_check_var_exists("PADDLE_TRAINERS_NUM")
_check_var_exists("PADDLE_TRAINER_ENDPOINTS")
# 3. init ParallelStrategy
strategy = ParallelStrategy()
if cpt.to_text(backend) == 'nccl':
if parallel_helper._is_parallel_ctx_initialized():
warnings.warn("The parallel environment has been initialized.")
strategy.nranks = ParallelEnv().world_size
strategy.local_rank = ParallelEnv().rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
if strategy.nranks < 2:
return
# NOTE(chenweihang): [ why config global place here? ]
# dygraph mode will be set as the default mode; users will not call
# `dygraph.guard` or `enable_dygraph` directly. If they want to switch
# the default place, they need to call a function to change it,
# so here we just set the correct place for users
place = core.CUDAPlace(ParallelEnv().device_id)
_set_expected_place(place)
# init nccl context
parallel_helper._set_parallel_ctx(
core.NCCLParallelContext(strategy, place))
parallel_helper._init_parallel_ctx()
def get_rank():
"""
Returns the rank of current trainer.
Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` .
The default value is 0.
Returns:
(int) The rank of current trainer.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
# execute this command in terminal: export PADDLE_TRAINER_ID=0
print("The rank is %d" % dist.get_rank())
# The rank is 0
"""
return ParallelEnv().rank
def get_world_size():
"""
The number of trainers (number of processes participating in current job).
Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` .
The default value is 1.
Returns:
(int) The number of trainers.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
# execute this command in terminal: export PADDLE_TRAINERS_NUM=4
print("The world_size is %d" % dist.get_world_size())
# The world_size is 4
"""
return ParallelEnv().world_size
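For reference, a rough sketch (not part of this patch) of how the environment variables checked in init_parallel_env above could be provided by hand for a hypothetical single-trainer run; normally paddle.distributed.launch or paddle.distributed.spawn exports them:
import os
# hypothetical single-node, single-trainer environment
os.environ.setdefault("FLAGS_selected_gpus", "0")
os.environ.setdefault("PADDLE_TRAINER_ID", "0")
os.environ.setdefault("PADDLE_CURRENT_ENDPOINT", "127.0.0.1:6170")
os.environ.setdefault("PADDLE_TRAINERS_NUM", "1")
os.environ.setdefault("PADDLE_TRAINER_ENDPOINTS", "127.0.0.1:6170")
import paddle.distributed as dist
dist.init_parallel_env()  # returns early because the world size is 1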
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, division
import multiprocessing
import os
import signal
import six
import sys
import warnings
from paddle.distributed.launch import get_cluster_and_pod, _print_arguments
from paddle.distributed.utils import _prepare_trainer_env
from paddle.device import get_device
# deprecated module import
from paddle.fluid import core
from paddle.fluid.framework import _cpu_num
# NOTE(chenweihang): The existence of this class leads to
# the maintenance of two sets of arguments. When the launch.py arguments
# are updated, the arguments here also need to be updated,
# but I have not thought of a better way to handle this yet
class ParallelEnvArgs(object):
def __init__(self):
# Paddle cluster nodes ips, such as 192.168.0.16,192.168.0.17..
self.cluster_node_ips = None
# The current node ip.
self.node_ip = None
# whether to use paddlecloud platform to run your multi-process job.
# If false, no need to set this argument.
self.use_paddlecloud = None
# The trainer's started port on a single node
self.started_port = None
# Print the config or not
self.print_config = True
# It's for gpu training and the training process will run
# on the selected_gpus, each process is bound to a single GPU.
# And if it's not set, this module will use all the gpu cards
# for training.
self.selected_gpus = None
def _py_supported_check():
if not sys.version_info >= (3, 4):
raise RuntimeError(
"Use `paddle.distributed.spawn` to start parallel training "
"requires python version greater than 3.4, if your python "
"is lower than this version, please use "
"`paddle.distributed.launch` instead.")
def _get_subprocess_env_list(nprocs, options):
# construct processes env list
processes_env_list = []
# get args from kwargs
args = ParallelEnvArgs()
# set default `node_ip` and `cluster_node_ips`
args.cluster_node_ips = options.get('cluster_node_ips', None)
args.node_ip = options.get('node_ip', None)
if args.cluster_node_ips is not None and args.node_ip is None:
raise ValueError("please input current node ip, "
"cannot only give `cluster_node_ips`.")
default_node_ip = "127.0.0.1"
if args.node_ip is None:
args.node_ip = default_node_ip
if args.cluster_node_ips is None:
args.cluster_node_ips = default_node_ip
# set default selected gpus
# e.g. if the nprocs is 4, the selected gpus is "0,1,2,3"
# NOTE(chenweihang): [ why not use FLAGS_selected_gpus directly? ]
# because FLAGS_selected_gpus may be used in other places;
# if we set FLAGS_selected_gpus to `0,1,2,3`, it may cause errors
# when using `ParallelEnv`
# NOTE(chenweihang): use absolute gpu card id
args.selected_gpus = options.get('selected_gpus', None)
env_devices = os.getenv("CUDA_VISIBLE_DEVICES", None)
if env_devices is None or env_devices == "":
env_devices_list = [
str(x) for x in six.moves.range(core.get_cuda_device_count())
]
else:
env_devices_list = env_devices.split(',')
if args.selected_gpus is None:
if len(env_devices_list) < nprocs:
raise RuntimeError(
"the number of visible devices(%d) is less than the number "
"of spawn processes(%d), please ensure that the correct "
"`nprocs` argument is passed or the environment variable "
"`CUDA_VISIBLE_DEVICES` is correctly configured." %
(len(env_devices_list), nprocs))
args.selected_gpus = ",".join(
[str(env_devices_list[x]) for x in range(0, nprocs)])
else:
for card_id in args.selected_gpus.split(','):
if card_id not in env_devices_list:
raise ValueError("The selected gpu card %s cannot found in "
"CUDA_VISIBLE_DEVICES (%s)." %
(card_id, ",".join(env_devices_list)))
# set other arguments
args.started_port = options.get('started_port', None)
args.use_paddlecloud = options.get('use_paddlecloud', False)
args.print_config = options.get('print_config', False)
# reuse code of launch.py
cluster, pod = get_cluster_and_pod(args)
# prepare subprocess env list
for trainer in pod.trainers:
processes_env_list.append(_prepare_trainer_env(cluster, trainer))
# print config
if args.print_config:
_print_arguments(args)
return processes_env_list
def _remove_risky_env():
# remove useless env vars, same as launch.py
# no copy, each process will hold env vars itself
os.environ.pop("http_proxy", None)
os.environ.pop("https_proxy", None)
def _set_trainer_env(env_dict):
for var_name in env_dict:
os.environ[var_name] = env_dict[var_name]
def _func_wrapper(func, args, error_queue, return_queue, env_dict):
try:
# config subprocess environment variables
_remove_risky_env()
_set_trainer_env(env_dict)
# execute function
result = func(*args)
# record function return value
return_queue.put(result)
except KeyboardInterrupt:
pass
except Exception:
import traceback
error_queue.put(traceback.format_exc())
sys.exit(1)
class MultiprocessContext(object):
def __init__(self, processes, error_queues, return_queues):
_py_supported_check()
self.error_queues = error_queues
# NOTE(chenweihang): The `spawn` method is mainly used
# to wrap the outermost execution function of the program for
# parallel execution. Generally, the return value is not of concern,
# but if the user needs to obtain it, the return result of each
# process can be read from context.return_queues
self.return_queues = return_queues
self.processes = processes
self.sentinels = {
process.sentinel: index
for index, process in enumerate(processes)
}
def join(self, timeout=None):
if len(self.sentinels) == 0:
return True
ready = multiprocessing.connection.wait(
self.sentinels.keys(), timeout=timeout)
error_index = None
for sentinel in ready:
index = self.sentinels.pop(sentinel)
process = self.processes[index]
process.join()
if process.exitcode != 0:
error_index = index
break
if error_index is None:
return len(self.sentinels) == 0
for process in self.processes:
if process.is_alive():
process.terminate()
process.join()
self._throw_exception(error_index)
def _throw_exception(self, error_index):
if self.error_queues[error_index].empty():
exitcode = self.processes[error_index].exitcode
if exitcode < 0:
name = signal.Signals(-exitcode).name
raise Exception("Process %d terminated with signal %s." %
(error_index, name))
else:
raise Exception("Process %d terminated with exit code %d." & (
error_index, exitcode))
original_trace = self.error_queues[error_index].get()
msg = "\n\n----------------------------------------------\n" \
"Process %d terminated with the following error:\n" \
"----------------------------------------------\n\n" % error_index
msg += original_trace
raise Exception(msg)
def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
"""
Start multiple processes with ``spawn`` method for parallel training.
Args:
func (function): The target function called by the spawned processes.
This function needs to be picklable, so it must be defined
at the top level of a module.
It is called as ``func(i, *args)``, where ``i`` is
the process index and ``args`` contains the other arguments as a tuple.
args (tuple, optional): Arguments passed to ``func``.
nprocs (int, optional): Number of processes to start. Default: -1.
When nprocs is -1, the available devices are detected when the model
is executed: if GPUs are used, the currently available device IDs are
read from the environment variable CUDA_VISIBLE_DEVICES; if CPUs are
used, the number of available CPUs is read from the environment
variable CPU_NUM (for example, export CPU_NUM=4; if it is not set,
the executor adds it to the environment with a value of 1).
join (bool, optional): Perform a blocking join on all spawned processes.
Default: True.
daemon (bool, optional): The spawned processes' daemon flag. Default: False.
**options(dict, optional): Other initial parallel execution environment
configuration options. The following options are currently supported:
(1) start_method (string): the way to start a process.
The start method can be ``spawn`` , ``fork`` , ``forkserver`` .
Because the CUDA runtime does not support the ``fork`` start method,
when use CUDA in subprocesses, we should start process by ``spawn``
or ``forkserver`` method. Default: "spawn" ;
(2) cluster_node_ips (string): Paddle cluster nodes ips, such as
"192.168.0.16,192.168.0.17". Default: "127.0.0.1";
(3) node_ip (string): The current node ip, such as "192.168.0.16".
Default: "127.0.0.1";
(4) started_port (int): The trainer's started port on a single node,
such as 6170. Default: None;
(5) selected_gpus (string): The training process will run on the
selected_gpus, such as "0,1,2,3". Default: None;
(6) print_config: Print current parallel training config. Default: False;
(7) use_paddlecloud: Whether to use paddlecloud platform to run your
multi-process job. Default: False.
Returns:
``MultiprocessContext`` object, which holds the spawned processes.
Examples:
.. code-block:: python
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train(print_result=False):
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env()
# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
if print_result is True:
print("loss:", loss.numpy())
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
# Usage 1: only pass function.
# If your training method does not need any arguments, and
# use all visible devices for parallel training.
if __name__ == '__main__':
dist.spawn(train)
# Usage 2: pass function and arguments.
# If your training method needs some arguments, and
# use all visible devices for parallel training.
if __name__ == '__main__':
dist.spawn(train, args=(True,))
# Usage 3: pass function, arguments and nprocs.
# If your training method needs some arguments, and
# only uses part of the visible devices for parallel training.
# If your machine holds 8 cards {0,1,2,3,4,5,6,7},
# this case will use cards {0,1}; If you set
# CUDA_VISIBLE_DEVICES=4,5,6,7, this case will use
# cards {4,5}
if __name__ == '__main__':
dist.spawn(train, args=(True,), nprocs=2)
# Usage 4: pass function, arguments, nprocs and selected_gpus.
# If your training method needs some arguments, and
# only uses part of the visible devices for parallel training,
# but you can't set your machine's environment variable
# CUDA_VISIBLE_DEVICES, e.g. it is None or holds all cards
# {0,1,2,3,4,5,6,7}, you can pass `selected_gpus` to
# select the GPU cards you want to use. For example,
# this case will use cards {4,5} if your machine holds 8 cards.
if __name__ == '__main__':
dist.spawn(train, args=(True,), nprocs=2, selected_gpus='4,5')
"""
# NOTE(chenweihang): [ why only supports python3.4+ ? ]
# Python supported setting the child process startup method
# since 3.4. Earlier versions can only use the default start
# method, and the default start method on Unix is fork, which
# does not support the CUDA runtime in multi-process settings
_py_supported_check()
# get default nprocs
if nprocs == -1:
device = get_device()
if device == 'cpu':
# TODO: cpu parallel is not supported yet
nprocs = _cpu_num
else:
nprocs = core.get_cuda_device_count()
# NOTE(chenweihang): [ why need get cluster info before run? ]
# when using `paddle.distributed.spawn` to start parallel training,
# we should get the cluster info before starting the subprocesses, and pass
# the correct info to each subprocess
procs_env_list = _get_subprocess_env_list(nprocs, options)
# start processes
# NOTE(chenweihang): [ why default start method is spawn? ]
# The CUDA runtime does not support the fork start method,
# either the spawn or forkserver start method is required
# to use CUDA in subprocesses.
start_method = options.get('start_method', None)
if start_method is None:
start_method = 'spawn'
mp = multiprocessing.get_context(start_method)
error_queues = []
return_queues = []
processes = []
for i in range(nprocs):
error_queue = mp.SimpleQueue()
return_queue = mp.SimpleQueue()
process = mp.Process(
target=_func_wrapper,
args=(func, args, error_queue, return_queue, procs_env_list[i]))
process.daemon = daemon
process.start()
error_queues.append(error_queue)
return_queues.append(return_queue)
processes.append(process)
context = MultiprocessContext(processes, error_queues, return_queues)
if not join:
return context
# loop until all process end
while not context.join():
pass
# finally return context
return context
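As a rough sketch (not from this patch), the non-blocking form of spawn together with the return queues mentioned in the NOTE above could be used like this, assuming a picklable train function that returns a value:
import paddle.distributed as dist

def train(print_result=False):
    # ... build the model and run one step (omitted) ...
    return 0  # hypothetical return value

if __name__ == '__main__':
    # join=False returns a MultiprocessContext immediately
    context = dist.spawn(train, args=(False,), nprocs=2, join=False)
    while not context.join(timeout=1):
        pass  # wait until all spawned processes finish
    # each process' return value can be read from its return queue
    results = [q.get() for q in context.return_queues]
    print(results)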
......@@ -327,6 +327,17 @@ def find_free_ports(num):
return None
def _prepare_trainer_env(cluster, trainer):
proc_env = {
"FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in trainer.gpus]),
"PADDLE_TRAINER_ID": "%d" % trainer.rank,
"PADDLE_CURRENT_ENDPOINT": "%s" % trainer.endpoint,
"PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
}
return proc_env
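For a hypothetical two-trainer pod on a single node, the helper above would return an env dict roughly like the following sketch (endpoints and GPU ids are made up):
# hypothetical result of _prepare_trainer_env(cluster, trainer) for trainer 0
example_proc_env = {
    "FLAGS_selected_gpus": "0",
    "PADDLE_TRAINER_ID": "0",
    "PADDLE_CURRENT_ENDPOINT": "127.0.0.1:6170",
    "PADDLE_TRAINERS_NUM": "2",
    "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:6170,127.0.0.1:6171",
}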
class TrainerProc(object):
def __init__(self):
self.proc = None
......@@ -352,14 +363,7 @@ def start_local_trainers(cluster,
procs = []
for idx, t in enumerate(pod.trainers):
proc_env = {
"FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]),
"PADDLE_TRAINER_ID": "%d" % t.rank,
"PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint,
"PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
}
proc_env = _prepare_trainer_env(cluster, t)
current_env.update(proc_env)
logger.debug("trainer proc env:{}".format(current_env))
......
......@@ -92,9 +92,11 @@ class TestWeightDecay(unittest.TestCase):
return param_sum
def check_weight_decay(self, place, model):
paddle.manual_seed(1)
paddle.framework.random._manual_program_seed(1)
main_prog = fluid.framework.Program()
startup_prog = fluid.framework.Program()
startup_prog.random_seed = 1
with prog_scope_guard(main_prog=main_prog, startup_prog=startup_prog):
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
......@@ -113,9 +115,11 @@ class TestWeightDecay(unittest.TestCase):
return param_sum
def check_weight_decay2(self, place, model):
paddle.manual_seed(1)
paddle.framework.random._manual_program_seed(1)
main_prog = fluid.framework.Program()
startup_prog = fluid.framework.Program()
startup_prog.random_seed = 1
with prog_scope_guard(main_prog=main_prog, startup_prog=startup_prog):
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
......
......@@ -327,19 +327,19 @@ def grad(outputs,
This API computes the sum of gradients of `outputs` with respect to each `inputs` .
Parameters:
outputs (Variable|list(Variable)|tuple(Variable)): the output Variable or
Variable list/tuple of the graph to compute gradients.
inputs (Variable|list(Variable)|tuple(Variable)): the input Variable or
Variable list/tuple of the graph to compute gradients. The returned
outputs (Tensor|list(Tensor)|tuple(Tensor)): the output Tensor or
Tensor list/tuple of the graph to compute gradients.
inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
Tensor list/tuple of the graph to compute gradients. The returned
values of this API are the gradients of `inputs` .
grad_outputs (Variable|list(Variable|None)|tuple(Variable|None), optional):
grad_outputs (Tensor|list(Tensor|None)|tuple(Tensor|None), optional):
initial gradient values of `outputs` . If `grad_outputs` is None,
the initial gradient values of `outputs` would be Tensors filled with 1;
if `grad_outputs` is not None, it must have the same length as `outputs` ,
and in this case, the initial gradient value of the i-th `outputs` would
be: (1) a Tensor filled with 1 when the i-th element of `grad_outputs`
is None; (2) the i-th element of `grad_outputs` when the i-th element of
`grad_outputs` is a Variable. Default None.
`grad_outputs` is a Tensor. Default None.
retain_graph (bool, optional): whether to retain the forward graph which
is used to calculate the gradient. When it is True, the graph would
be retained, in which way users can calculate backward twice for the
......@@ -351,21 +351,21 @@ def grad(outputs,
computing process would be discarded. Default False.
only_inputs (bool, optional): whether to only compute the gradients of
`inputs` . If it is False, the gradients of all remaining leaf
Variables in the graph would be also computed and accumulated.
Tensors in the graph would be also computed and accumulated.
If it is True, only the gradients of `inputs` would be computed.
Default True. only_inputs=False is under development, and it is
not supported yet.
allow_unused (bool, optional): whether to raise error or return None if some
Variables of `inputs` are unreachable in the graph. If some Variables of
Tensors of `inputs` are unreachable in the graph. If some Tensors of
`inputs` are unreachable in the graph (i.e., their gradients are None),
error would be raised if allow_unused=False, or None would be returned as
their gradients if allow_unused=True. Default False.
no_grad_vars (Variable|list(Variable)|tuple(Variable)|set(Variable), optional):
the Variables whose gradients are not needed to compute. Default None.
no_grad_vars (Tensor|list(Tensor)|tuple(Tensor)|set(Tensor), optional):
the Tensors whose gradients are not needed to compute. Default None.
Returns:
tuple: a tuple of Variables, whose length is the same as the Variable number
inside `inputs`, and the i-th returned Variable is the sum of gradients of
tuple: a tuple of Tensors, whose length is the same as the Tensor number
inside `inputs`, and the i-th returned Tensor is the sum of gradients of
`outputs` with respect to the i-th `inputs`.
Examples 1:
......
......@@ -11,21 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import numpy as np
import warnings
from collections import OrderedDict
from .. import core
from . import layers
from . import parallel_helper
from .. import framework
from . import to_variable, no_grad
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]
ParallelStrategy = core.ParallelStrategy
@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
'''
:api_attr: imperative
......@@ -39,17 +44,18 @@ def prepare_context(strategy=None):
if strategy.nranks < 2:
return
assert framework.in_dygraph_mode() is True, \
"dygraph.prepare_context should be used with dygrahp mode."
"dygraph.prepare_context should be used with dygraph mode."
place = framework._current_expected_place()
assert place is not None, \
"dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."
if isinstance(place, core.CUDAPlace):
parallel_helper._set_parallel_ctx(
core.NCCLParallelContext(strategy, place))
else:
# TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
assert ("Only support CUDAPlace for now.")
parallel_helper._init_parallel_ctx()
if not parallel_helper._is_parallel_ctx_initialized():
if isinstance(place, core.CUDAPlace):
parallel_helper._set_parallel_ctx(
core.NCCLParallelContext(strategy, place))
else:
# TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
assert ("Only support CUDAPlace for now.")
parallel_helper._init_parallel_ctx()
return strategy
......@@ -112,84 +118,84 @@ class ParallelEnv(object):
"""
def __init__(self):
self._nranks = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
self._local_rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self._dev_id = int(os.getenv("FLAGS_selected_gpus", "0"))
self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
self._device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
"").split(",")
self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
@property
def nranks(self):
def rank(self):
"""
The number of trainers, generally refers to the number of GPU cards used in training.
Rank of current trainer.
Its value is equal to the value of the environment variable PADDLE_TRAINERS_NUM. The default value is 1.
Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.
Examples:
.. code-block:: python
# execute this command in terminal: export PADDLE_TRAINERS_NUM=4
import paddle.fluid as fluid
# execute this command in terminal: export PADDLE_TRAINER_ID=0
import paddle.distributed as dist
env = fluid.dygraph.ParallelEnv()
print("The nranks is %d" % env.nranks)
# The nranks is 4
env = dist.ParallelEnv()
print("The rank is %d" % env.rank)
# The rank is 0
"""
return self._nranks
return self._rank
@property
def local_rank(self):
def world_size(self):
"""
The current trainer number.
The number of trainers (number of processes participating in current job).
Its value is equal to the value of the environment variable PADDLE_TRAINER_ID. The default value is 0.
Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.
Examples:
.. code-block:: python
# execute this command in terminal: export PADDLE_TRAINER_ID=0
import paddle.fluid as fluid
# execute this command in terminal: export PADDLE_TRAINERS_NUM=4
import paddle.distributed as dist
env = fluid.dygraph.ParallelEnv()
print("The local rank is %d" % env.local_rank)
# The local rank is 0
env = dist.ParallelEnv()
print("The world_size is %d" % env.world_size)
# The world_size is 4
"""
return self._local_rank
return self._world_size
@property
def dev_id(self):
def device_id(self):
"""
The ID of selected GPU card for parallel training.
Its value is equal to the value of the environment variable FLAGS_selected_gpus. The default value is 0.
Its value is equal to the value of the environment variable ``FLAGS_selected_gpus`` . The default value is 0.
Examples:
.. code-block:: python
# execute this command in terminal: export FLAGS_selected_gpus=1
import paddle.fluid as fluid
import paddle.distributed as dist
env = fluid.dygraph.ParallelEnv()
print("The device id are %d" % env.dev_id)
env = dist.ParallelEnv()
print("The device id are %d" % env.device_id)
# The device id are 1
"""
return self._dev_id
return self._device_id
@property
def current_endpoint(self):
"""
The endpoint of current trainer, it is in the form of (node IP + port).
Its value is equal to the value of the environment variable PADDLE_CURRENT_ENDPOINT. The default value is "".
Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".
Examples:
.. code-block:: python
# execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
import paddle.fluid as fluid
import paddle.distributed as dist
env = fluid.dygraph.ParallelEnv()
env = dist.ParallelEnv()
print("The current endpoint are %s" % env.current_endpoint)
# The current endpoint are 127.0.0.1:6170
"""
......@@ -201,20 +207,25 @@ class ParallelEnv(object):
The endpoints of all trainer nodes in the task,
which are used to broadcast the NCCL ID when NCCL2 is initialized.
Its value is equal to the value of the environment variable PADDLE_TRAINER_ENDPOINTS. The default value is "".
Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ENDPOINTS`` . The default value is "".
Examples:
.. code-block:: python
# execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
import paddle.fluid as fluid
import paddle.distributed as dist
env = fluid.dygraph.ParallelEnv()
env = dist.ParallelEnv()
print("The trainer endpoints are %s" % env.trainer_endpoints)
# The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
"""
return self._trainer_endpoints
# [aliases] Compatible with old method names
local_rank = rank
nranks = world_size
dev_id = device_id
# NOTE: [ Compatible ] Originally this class name is `Env`. The semantics of the old class names
# are inaccurate and may confuse users, so replace it with `ParallelEnv`, but to be compatible
......@@ -227,61 +238,98 @@ class DataParallel(layers.Layer):
Run the dygraph module with data parallelism.
Currently, the DataParallel class only supports running the dynamic graph
with multi-process. The usage is:
`python -m paddle.distributed.launch --selected_gpus=0,1 dynamic_graph_test.py`.
And the content of `dynamic_graph_test.py` is the code of examples.
with multiple processes.
It now supports two ways to start training:
1. start by ``paddle.distributed.spawn`` method, for example:
``python demo.py`` (spawn needs to be called in the ``__main__`` block)
2. start by ``paddle.distributed.launch`` module, for example:
``python -m paddle.distributed.launch --selected_gpus=0,1 demo.py`` .
And the content of ``demo.py`` is the code shown in the examples below.
Args:
layers(Layer): The module that should be executed by data parallel.
strategy(ParallelStrategy): The strategy of data parallelism, contains
environment configuration related to parallel execution.
strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
contains environment configuration related to parallel execution. Default: None.
Returns:
Layer: The data paralleled module.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id)
with fluid.dygraph.guard(place):
# prepare the data parallel context
strategy = fluid.dygraph.prepare_context()
linear = fluid.dygraph.Linear(1, 10, act="softmax")
adam = fluid.optimizer.AdamOptimizer(
learning_rate=0.001, parameter_list=linear.parameters())
# make the module become the data parallelism module
linear = fluid.dygraph.DataParallel(linear, strategy)
x_data = np.random.random(size=[10, 1]).astype(np.float32)
data = fluid.dygraph.to_variable(x_data)
hidden = linear(data)
avg_loss = fluid.layers.mean(hidden)
# scale the loss according to the number of trainers.
avg_loss = linear.scale_loss(avg_loss)
avg_loss.backward()
# collect the gradients of trainers.
linear.apply_collective_grads()
adam.minimize(avg_loss)
linear.clear_gradients()
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env()
# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
# 1. start by ``paddle.distributed.spawn`` (default)
dist.spawn(train, nprocs=2)
# 2. start by ``paddle.distributed.launch``
# train()
"""
def __init__(self, layers, strategy):
def __init__(self, layers, strategy=None):
super(DataParallel,
self).__init__(layers.full_name() + "_data_parallel")
self._layers = layers
self._strategy = strategy
# NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
# It just stores some environment variables, which can be constructed by
# ParallelEnv. Here it is set as an optional argument.
# This parameter is not removed because of compatibility with 1.x writing.
if strategy is not None:
self._strategy = strategy
else:
self._strategy = ParallelStrategy()
self._strategy.nranks = ParallelEnv().nranks
self._strategy.local_rank = ParallelEnv().local_rank
self._strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
self._strategy.current_endpoint = ParallelEnv().current_endpoint
def forward(self, *inputs, **kwargs):
return self._layers(*inputs, **kwargs)
......
......@@ -23,6 +23,11 @@ def _is_data_parallel_mode():
os.getenv("PADDLE_TRAINERS_NUM", "1")) > 1
def _is_parallel_ctx_initialized():
global __parallel_ctx__clz__
return __parallel_ctx__clz__ is not None
def _set_parallel_ctx(nccl_parallel_context):
global __parallel_ctx__clz__
assert __parallel_ctx__clz__ is None, \
......
......@@ -17,44 +17,28 @@ from . import core
__all__ = ['Generator']
default_rng_seed_val = 34342423252
class Generator(object):
class Generator(core.Generator):
"""Generator class"""
def __init__(self, device="CPU"):
"""init"""
self.device = device
seed_in = default_rng_seed_val
if self.device == "CPU":
self.generator = core.Generator()
# self.generator.manual_seed(seed_in)
else:
raise ValueError(
"generator class with device %s does not exist, currently only support generator with device 'CPU' "
% device)
def get_state(self):
return self.generator.get_state()
def set_state(self, state):
self.generator.set_state(state)
def __init__(self, place=None):
"""
Create a generator object which manages the random number generation. ( Experimental Feature )
def manual_seed(self, seed):
self.generator.manual_seed(seed)
Parameters:
place(CPUPlace|CUDAPinnedPlace|CUDAPlace, optional): The place to allocate Tensor. Can be
CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place.
def seed(self):
return self.generator.seed()
Returns:
Generator: A generator object.
def initial_seed(self):
return self.generator.initial_seed()
def random(self):
return self.generator.random()
def get_cpu_engine(self):
return self.generator.get_cpu_engine()
def set_cpu_engine(self, cpu_engine):
self.generator.set_cpu_engine(cpu_engine)
"""
self.place = place
if not place:
place = core.CPUPlace()
if isinstance(place, core.CPUPlace):
super(Generator, self).__init__()
else:
raise ValueError(
"Generator class with %s does is not supported yet, currently only support generator with CPUPlace "
% place)
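A minimal usage sketch of the new Generator wrapper (the import path is assumed; the methods shown are the ones bound in BindGenerator earlier in this change):
from paddle.fluid.generator import Generator  # assumed module path

gen = Generator()            # defaults to a CPUPlace-backed generator
gen.manual_seed(1234)        # fix the seed; returns the generator itself
print(gen.initial_seed())    # -> 1234
state = gen.get_state()      # snapshot the CPU engine state
_ = gen.random()             # draw a raw 64-bit value
gen.set_state(state)         # restore the snapshot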
......@@ -1858,6 +1858,7 @@ def conv3d(input,
return helper.append_activation(pre_act)
@deprecated(since="2.0.0", update_to="paddle.nn.functional.pool2d")
@templatedoc()
def pool2d(input,
pool_size=-1,
......@@ -2075,6 +2076,7 @@ def pool2d(input,
return pool_out
@deprecated(since="2.0.0", update_to="paddle.nn.functional.pool3d")
@templatedoc()
def pool3d(input,
pool_size=-1,
......@@ -2303,6 +2305,7 @@ def pool3d(input,
return pool_out
@deprecated(since="2.0.0", update_to="paddle.nn.functional.adaptive_pool2d")
@templatedoc(op_type="pool2d")
def adaptive_pool2d(input,
pool_size,
......@@ -2450,6 +2453,7 @@ def adaptive_pool2d(input,
return (pool_out, mask) if require_index else pool_out
@deprecated(since="2.0.0", update_to="paddle.nn.functional.adaptive_pool3d")
@templatedoc(op_type="pool3d")
def adaptive_pool3d(input,
pool_size,
......@@ -10205,6 +10209,7 @@ def unstack(x, axis=0, num=None):
return outs
@deprecated(since='2.0.0', update_to="paddle.expand")
def expand(x, expand_times, name=None):
"""
:alias_main: paddle.expand
......@@ -10312,6 +10317,7 @@ def expand(x, expand_times, name=None):
return out
@deprecated(since='2.0.0', update_to="paddle.expand_as")
def expand_as(x, target_tensor, name=None):
"""
:alias_main: paddle.expand_as
......@@ -10377,6 +10383,9 @@ def expand_as(x, target_tensor, name=None):
#(3,20)
"""
if in_dygraph_mode():
return core.ops.expand_as(x, target_tensor)
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64', 'bool'], 'expand_as')
check_variable_and_dtype(target_tensor, 'target_tensor',
......
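The hunks above deprecate the fluid.layers pooling and expand entry points in favor of the 2.0 names carried in each update_to, and give expand_as a dygraph fast path through core.ops.expand_as. A small sketch of what a caller now sees; the DeprecationWarning behavior and the commented replacement call are assumptions based on the decorator, not on code shown here.

import warnings
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.disable_static()
x = paddle.to_tensor(np.array([1, 2, 3], dtype='int32'))
y = paddle.to_tensor(np.array([1, 2, 3, 1, 2, 3], dtype='int32'))

# The fluid API still works, but in dygraph mode it now dispatches straight to
# core.ops.expand_as, and the @deprecated decorator is expected to warn.
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    out = fluid.layers.expand_as(x, y)
print(out.numpy())  # [1 2 3 1 2 3]

# Suggested replacement named by the decorator (exact 2.0 signature assumed):
# out = paddle.expand_as(x, y)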
......@@ -61,7 +61,14 @@ for i in {1..2}; do
fi
done
echo "dist space:"
df -h
#display /tmp/files
echo "ls /tmp/paddle.*"
ls -l /tmp/paddle.*
echo "ls -l ./"
ls -l ./
exit 1
......@@ -15,6 +15,7 @@
import math
import numpy as np
import unittest
import paddle
from paddle.jit import to_static
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
......@@ -560,8 +561,8 @@ def train_bmn(args, place, to_static):
loss_data = []
with fluid.dygraph.guard(place):
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
global local_random
local_random = np.random.RandomState(SEED)
......
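The same substitution recurs in every dygraph-to-static test below: the per-program random_seed assignments are replaced by one call that seeds the global generator plus the private helper that seeds the default programs. The before/after pattern, with a placeholder seed value:

import paddle
import paddle.fluid as fluid

SEED = 2020  # placeholder

# Old pattern (pre-2.0): seed each default program object separately.
# fluid.default_startup_program().random_seed = SEED
# fluid.default_main_program().random_seed = SEED

# New pattern: seed the global generator, then let the private helper
# propagate the seed to the default startup/main programs.
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)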
......@@ -21,6 +21,7 @@ import unittest
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from paddle.fluid.dygraph import Embedding, Linear, GRUUnit
......@@ -448,8 +449,8 @@ def do_train(args, to_static):
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
reader = get_random_input_data(args.batch_size, args.vocab_size,
args.num_labels)
......
......@@ -14,6 +14,7 @@
import time
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
......@@ -447,8 +448,8 @@ def train_mobilenet(args, to_static):
with fluid.dygraph.guard(args.place):
np.random.seed(SEED)
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
if args.model == "MobileNetV1":
net = MobileNetV1(class_dim=args.class_dim, scale=1.0)
......
......@@ -19,7 +19,7 @@ import time
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.base import to_variable
......@@ -218,8 +218,8 @@ def train(place):
batch_num = 200
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
ptb_model = PtbModel(
hidden_size=hidden_size,
vocab_size=vocab_size,
......
......@@ -16,6 +16,7 @@ import gym
import math
import itertools
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
from paddle.fluid.dygraph import to_variable, Layer
......@@ -64,8 +65,8 @@ def train(args, place, to_static):
env.seed(SEED)
with fluid.dygraph.guard(place):
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
local_random = np.random.RandomState(SEED)
policy = Policy()
......
......@@ -215,8 +215,8 @@ def train(to_static):
"""
with fluid.dygraph.guard(place):
np.random.seed(SEED)
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
train_reader = paddle.batch(
reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
......
......@@ -331,8 +331,8 @@ def train(train_reader, to_static):
np.random.seed(SEED)
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
se_resnext = SeResNeXt()
optimizer = optimizer_setting(train_parameters, se_resnext.parameters())
......
......@@ -15,6 +15,7 @@ import time
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Linear, Embedding
from paddle.fluid.dygraph import to_variable, ProgramTranslator, declarative
......@@ -285,8 +286,8 @@ def train(args, to_static):
with fluid.dygraph.guard(place):
np.random.seed(SEED)
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
train_reader = fake_data_reader(args.class_num, args.vocab_size,
args.batch_size, args.padding_size)
......
......@@ -108,8 +108,8 @@ def train(conf_dict, to_static):
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
conf_dict['dict_size'] = len(vocab)
conf_dict['seq_len'] = args.seq_len
......
......@@ -18,6 +18,7 @@ import time
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import transformer_util as util
......@@ -31,10 +32,11 @@ STEP_NUM = 10
def train_static(args, batch_generator):
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
train_prog = fluid.Program()
startup_prog = fluid.Program()
train_prog.random_seed = SEED
startup_prog.random_seed = SEED
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
# define input and reader
......@@ -128,8 +130,8 @@ def train_static(args, batch_generator):
def train_dygraph(args, batch_generator):
with fluid.dygraph.guard(place):
if SEED is not None:
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
# define data loader
train_loader = fluid.io.DataLoader.from_generator(capacity=10)
train_loader.set_batch_generator(batch_generator, places=place)
......@@ -220,7 +222,8 @@ def train_dygraph(args, batch_generator):
def predict_dygraph(args, batch_generator):
with fluid.dygraph.guard(place):
fluid.default_main_program().random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
# define data loader
test_loader = fluid.io.DataLoader.from_generator(capacity=10)
......@@ -291,7 +294,8 @@ def predict_dygraph(args, batch_generator):
def predict_static(args, batch_generator):
test_prog = fluid.Program()
with fluid.program_guard(test_prog):
test_prog.random_seed = SEED
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
# define input and reader
input_field_names = util.encoder_data_input_fields + util.fast_decoder_data_input_fields
......
......@@ -20,7 +20,7 @@ import random
import sys
import time
import unittest
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import declarative, ProgramTranslator, to_variable
from paddle.fluid.dygraph.nn import Conv2D, BatchNorm, Linear, Pool2D
......@@ -272,8 +272,8 @@ def train(args, fake_data_reader, to_static):
random.seed(0)
np.random.seed(0)
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = 1000
fluid.default_main_program().random_seed = 1000
paddle.manual_seed(1000)
paddle.framework.random._manual_program_seed(1000)
video_model = TSM_ResNet("TSM", train_config, 'Train')
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import multiprocessing
import os
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import compiler
......@@ -64,10 +65,11 @@ class TestParallelExecutorBase(unittest.TestCase):
feed_data_reader, FeedDataReader
), "feed_data_reader must be type of FeedDataReader"
paddle.manual_seed(1)
paddle.framework.random._manual_program_seed(1)
main = fluid.Program()
startup = fluid.Program()
startup.random_seed = 1
main.random_seed = 1
with fluid.program_guard(main, startup):
feed_dict, loss = cls.build_model(feed_dict, get_data_from_feeder,
main, method, optimizer)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, division
import numpy as np
import unittest
import paddle
# used by model.run_trainer in test_dist_base
from test_dist_base import RUN_STEP
# NOTE: args compatible with TestParallelDyGraphRunnerBase
class SpawnAssistTestArgs(object):
update_method = "local"
trainer_id = 0
class TestDistSpawnRunner(unittest.TestCase):
def setUp(self):
# NOTE(chenweihang): keep consistent with
# TestDistBase.check_with_place
self.nprocs = 2
def _run(self, model, args):
args.update_method = "local"
return model.run_trainer_with_spawn(args)
def _run_parallel(self, model, args):
args.update_method = "nccl2"
context = paddle.distributed.spawn(
func=model.run_trainer_with_spawn,
args=(args, ),
nprocs=self.nprocs,
join=True)
result_list = []
for res_queue in context.return_queues:
result_list.append(res_queue.get())
return result_list
def check_dist_result_with_spawn(self, test_class, delta=1e-3):
# 0. prepare model and args
model = test_class()
args = SpawnAssistTestArgs()
# 1. calc single-card loss
losses = self._run(model, args)
# 2. calc multi-card loss (nccl mode)
dist_losses_list = self._run_parallel(model, args)
# 3. compare losses
for step_id in range(RUN_STEP):
loss = losses[step_id]
dist_loss_sum = None
for dist_losses in dist_losses_list:
if dist_loss_sum is None:
dist_loss_sum = np.array(dist_losses[step_id])
else:
dist_loss_sum += np.array(dist_losses[step_id])
dist_loss = dist_loss_sum / self.nprocs
self.assertAlmostEqual(
loss,
dist_loss,
delta=delta,
msg="The results of single-card execution and multi-card execution are inconsistent."
"signal-card loss is:\n{}\nmulti-card average loss is:\n{}\n".
format(loss, dist_loss))
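A hedged, self-contained sketch of the helper's contract when appended to this module: FakeRunner is hypothetical and returns a deterministic loss curve, and _run_parallel is overridden so no processes are actually spawned, which makes the single-card and averaged multi-card results agree exactly.

import unittest
import numpy as np
from test_dist_base import RUN_STEP

class FakeRunner(object):
    def run_trainer_with_spawn(self, args):
        # every "process" produces the same loss at each step
        return [float(step) for step in range(RUN_STEP)]

class TestFakeRunnerSpawn(TestDistSpawnRunner):
    def _run_parallel(self, model, args):
        # replace the real multi-process run with nprocs identical copies
        return [model.run_trainer_with_spawn(args) for _ in range(self.nprocs)]

    def test_fake_runner(self):
        self.check_dist_result_with_spawn(test_class=FakeRunner)

if __name__ == "__main__":
    unittest.main()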
......@@ -17,6 +17,8 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
class TestAdadeltaOp1(OpTest):
......@@ -108,5 +110,54 @@ class TestAdadeltaOp2(OpTest):
self.check_output()
class TestAdadeltaV2(unittest.TestCase):
def test_adadelta_dygraph(self):
paddle.disable_static(paddle.CPUPlace())
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value)
linear = paddle.nn.Linear(13, 5)
# This can be any optimizer supported by dygraph.
adam = paddle.optimizer.Adadelta(
learning_rate=0.01,
parameters=linear.parameters(),
weight_decay=0.01)
out = linear(a)
out.backward()
adam.step()
adam.clear_gradients()
def test_adadelta(self):
place = fluid.CPUPlace()
main = fluid.Program()
with fluid.program_guard(main):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
rms_optimizer = paddle.optimizer.Adadelta(learning_rate=0.1)
rms_optimizer.minimize(avg_cost)
fetch_list = [avg_cost]
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=1)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for data in train_reader():
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
def test_raise_error(self):
self.assertRaises(ValueError, paddle.optimizer.Adadelta, None)
self.assertRaises(
ValueError, paddle.optimizer.Adadelta, learning_rate=0.1, rho=None)
self.assertRaises(
ValueError,
paddle.optimizer.Adadelta,
learning_rate=0.1,
epsilon=None)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
import paddle
import paddle.nn.functional as F
def adaptive_start_index(index, input_size, output_size):
return int(np.floor(index * input_size / output_size))
def adaptive_end_index(index, input_size, output_size):
return int(np.ceil((index + 1) * input_size / output_size))
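The two index helpers implement the adaptive-pooling window rule start = floor(i * L / out), end = ceil((i + 1) * L / out), so adjacent windows can overlap. A small worked example for an input length of 7 reduced to 3 outputs:

# Worked example: windows are [0, 3), [2, 5), [4, 7) — they overlap by one element.
for i in range(3):
    start = adaptive_start_index(i, 7, 3)
    end = adaptive_end_index(i, 7, 3)
    print(i, start, end)  # 0 0 3 / 1 2 5 / 2 4 7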
def avg_pool1D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=False,
adaptive=False,
data_type=np.float64):
N, C, L = x.shape
if global_pool == 1:
ksize = [L]
if adaptive:
L_out = ksize[0]
else:
L_out = (L - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
L - ksize[0] + 2 * paddings[0]) // strides[0] + 1
out = np.zeros((N, C, L_out))
for i in range(L_out):
if adaptive:
r_start = adaptive_start_index(i, L, ksize[0])
r_end = adaptive_end_index(i, L, ksize[0])
else:
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], L))
x_masked = x[:, :, r_start:r_end]
field_size = (r_end - r_start) \
if (exclusive or adaptive) else (ksize[0])
if data_type == np.int8 or data_type == np.uint8:
out[:, :, i] = (np.rint(
np.sum(x_masked, axis=2) / field_size)).astype(data_type)
else:
out[:, :, i] = (np.sum(x_masked, axis=(2)) /
field_size).astype(data_type)
return out
class TestPool1d_API(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(fluid.CUDAPlace(0))
def check_adaptive_avg_dygraph_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result = F.adaptive_avg_pool1d(input, output_size=16)
result_np = avg_pool1D_forward_naive(
input_np, ksize=[16], strides=[0], paddings=[0], adaptive=True)
self.assertTrue(np.allclose(result.numpy(), result_np))
ada_avg_pool1d_dg = paddle.nn.layer.AdaptiveAvgPool1d(
output_size=16)
result = ada_avg_pool1d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_adaptive_avg_static_results(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input = fluid.data(name="input", shape=[2, 3, 32], dtype="float32")
result = F.adaptive_avg_pool1d(input, output_size=16)
input_np = np.random.random([2, 3, 32]).astype("float32")
result_np = avg_pool1D_forward_naive(
input_np, ksize=[16], strides=[2], paddings=[0], adaptive=True)
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input": input_np},
fetch_list=[result])
self.assertTrue(np.allclose(fetches[0], result_np))
def test_adaptive_avg_pool1d(self):
for place in self.places:
self.check_adaptive_avg_dygraph_results(place)
self.check_adaptive_avg_static_results(place)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import unittest
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid import compiler, Program, program_guard
import paddle
import paddle.nn.functional as F
import paddle.fluid as fluid
def adaptive_start_index(index, input_size, output_size):
return int(np.floor(index * input_size / output_size))
def adaptive_end_index(index, input_size, output_size):
return int(np.ceil((index + 1) * input_size / output_size))
def max_pool1D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=False,
adaptive=False,
data_type=np.float64):
N, C, L = x.shape
if global_pool == 1:
ksize = [L]
if adaptive:
L_out = ksize[0]
else:
L_out = (L - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
L - ksize[0] + 2 * paddings[0]) // strides[0] + 1
out = np.zeros((N, C, L_out))
for i in range(L_out):
if adaptive:
r_start = adaptive_start_index(i, L, ksize[0])
r_end = adaptive_end_index(i, L, ksize[0])
else:
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], L))
x_masked = x[:, :, r_start:r_end]
out[:, :, i] = np.max(x_masked, axis=(2))
return out
class TestPool1d_API(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(fluid.CUDAPlace(0))
def check_adaptive_max_dygraph_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result = F.adaptive_max_pool1d(input, output_size=16)
result_np = max_pool1D_forward_naive(
input_np, ksize=[16], strides=[0], paddings=[0], adaptive=True)
self.assertTrue(np.allclose(result.numpy(), result_np))
ada_max_pool1d_dg = paddle.nn.layer.AdaptiveMaxPool1d(
output_size=16)
result = ada_max_pool1d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_adaptive_max_static_results(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input = fluid.data(name="input", shape=[2, 3, 32], dtype="float32")
result = F.adaptive_max_pool1d(input, output_size=16)
input_np = np.random.random([2, 3, 32]).astype("float32")
result_np = max_pool1D_forward_naive(
input_np, ksize=[16], strides=[2], paddings=[0], adaptive=True)
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input": input_np},
fetch_list=[result])
self.assertTrue(np.allclose(fetches[0], result_np))
def test_adaptive_max_pool1d(self):
for place in self.places:
self.check_adaptive_max_dygraph_results(place)
self.check_adaptive_max_static_results(place)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
def adaptive_start_index(index, input_size, output_size):
return int(np.floor(index * input_size / output_size))
def adaptive_end_index(index, input_size, output_size):
return int(np.ceil((index + 1) * input_size / output_size))
def adaptive_pool2d_forward(x, output_size, data_format='NCHW',
pool_type="max"):
N = x.shape[0]
C, H, W = [x.shape[1], x.shape[2], x.shape[3]] if data_format == 'NCHW' \
else [x.shape[3], x.shape[1], x.shape[2]]
if (isinstance(output_size, int) or output_size == None):
H_out = output_size
W_out = output_size
output_size = [H_out, W_out]
else:
H_out, W_out = output_size
if output_size[0] == None:
output_size[0] = H
H_out = H
if output_size[1] == None:
output_size[1] = W
W_out = W
out = np.zeros((N, C, H_out, W_out)) if data_format=='NCHW' \
else np.zeros((N, H_out, W_out, C))
for i in range(H_out):
in_h_start = adaptive_start_index(i, H, output_size[0])
in_h_end = adaptive_end_index(i, H, output_size[0])
for j in range(W_out):
in_w_start = adaptive_start_index(j, W, output_size[1])
in_w_end = adaptive_end_index(j, W, output_size[1])
if data_format == 'NCHW':
x_masked = x[:, :, in_h_start:in_h_end, in_w_start:in_w_end]
if pool_type == 'avg':
field_size = (
(in_h_end - in_h_start) * (in_w_end - in_w_start))
out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size
elif pool_type == 'max':
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
elif data_format == 'NHWC':
x_masked = x[:, in_h_start:in_h_end, in_w_start:in_w_end, :]
if pool_type == 'avg':
field_size = (
(in_h_end - in_h_start) * (in_w_end - in_w_start))
out[:, i, j, :] = np.sum(x_masked, axis=(1, 2)) / field_size
elif pool_type == 'max':
out[:, i, j, :] = np.max(x_masked, axis=(1, 2))
return out
class TestAdaptiveMaxPool2dAPI(unittest.TestCase):
def setUp(self):
self.x_np = np.random.random([2, 3, 7, 7]).astype("float32")
self.res_1_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[3, 3], pool_type="max")
self.res_2_np = adaptive_pool2d_forward(
x=self.x_np, output_size=5, pool_type="max")
self.res_3_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[2, 5], pool_type="max")
"""
self.res_4_np = adaptive_pool2d_forward(
x=self.x_np,
output_size=[3, 3],
pool_type="max",
data_format="NHWC")
"""
self.res_5_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[None, 3], pool_type="max")
def test_static_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x = paddle.data(name="x", shape=[2, 3, 7, 7], dtype="float32")
out_1 = paddle.nn.functional.adaptive_max_pool2d(
x=x, output_size=[3, 3])
out_2 = paddle.nn.functional.adaptive_max_pool2d(x=x, output_size=5)
out_3 = paddle.nn.functional.adaptive_max_pool2d(
x=x, output_size=[2, 5])
#out_4 = paddle.nn.functional.adaptive_max_pool2d(
# x=x, output_size=[3, 3], data_format="NHWC")
out_5 = paddle.nn.functional.adaptive_max_pool2d(
x=x, output_size=[None, 3])
exe = paddle.static.Executor(place=place)
[res_1, res_2, res_3, res_5] = exe.run(
fluid.default_main_program(),
feed={"x": self.x_np},
fetch_list=[out_1, out_2, out_3, out_5])
assert np.allclose(res_1, self.res_1_np)
assert np.allclose(res_2, self.res_2_np)
assert np.allclose(res_3, self.res_3_np)
#assert np.allclose(res_4, self.res_4_np)
assert np.allclose(res_5, self.res_5_np)
def test_dynamic_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
x = paddle.to_variable(self.x_np)
out_1 = paddle.nn.functional.adaptive_max_pool2d(
x=x, return_indices=False, output_size=[3, 3])
out_2 = paddle.nn.functional.adaptive_max_pool2d(x=x, output_size=5)
out_3 = paddle.nn.functional.adaptive_max_pool2d(
x=x, output_size=[2, 5])
#out_4 = paddle.nn.functional.adaptive_max_pool2d(
# x=x, output_size=[3, 3], data_format="NHWC")
out_5 = paddle.nn.functional.adaptive_max_pool2d(
x=x, output_size=[None, 3])
assert np.allclose(out_1.numpy(), self.res_1_np)
assert np.allclose(out_2.numpy(), self.res_2_np)
assert np.allclose(out_3.numpy(), self.res_3_np)
#assert np.allclose(out_4.numpy(), self.res_4_np)
assert np.allclose(out_5.numpy(), self.res_5_np)
class TestAdaptiveMaxPool2dClassAPI(unittest.TestCase):
def setUp(self):
self.x_np = np.random.random([2, 3, 7, 7]).astype("float32")
self.res_1_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[3, 3], pool_type="max")
self.res_2_np = adaptive_pool2d_forward(
x=self.x_np, output_size=5, pool_type="max")
self.res_3_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[2, 5], pool_type="max")
#self.res_4_np = adaptive_pool2d_forward(
# x=self.x_np,
# output_size=[3, 3],
# pool_type="max",
# data_format="NHWC")
self.res_5_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[None, 3], pool_type="max")
def test_static_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x = paddle.data(name="x", shape=[2, 3, 7, 7], dtype="float32")
adaptive_max_pool = paddle.nn.AdaptiveMaxPool2d(output_size=[3, 3])
out_1 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool2d(output_size=5)
out_2 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool2d(output_size=[2, 5])
out_3 = adaptive_max_pool(x=x)
# adaptive_max_pool = paddle.nn.AdaptiveMaxPool2d(
# output_size=[3, 3], data_format="NHWC")
# out_4 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool2d(
output_size=[None, 3])
out_5 = adaptive_max_pool(x=x)
exe = paddle.static.Executor(place=place)
[res_1, res_2, res_3, res_5] = exe.run(
fluid.default_main_program(),
feed={"x": self.x_np},
fetch_list=[out_1, out_2, out_3, out_5])
assert np.allclose(res_1, self.res_1_np)
assert np.allclose(res_2, self.res_2_np)
assert np.allclose(res_3, self.res_3_np)
#assert np.allclose(res_4, self.res_4_np)
assert np.allclose(res_5, self.res_5_np)
def test_dynamic_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
x = paddle.to_variable(self.x_np)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool2d(output_size=[3, 3])
out_1 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool2d(output_size=5)
out_2 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool2d(output_size=[2, 5])
out_3 = adaptive_max_pool(x=x)
#adaptive_max_pool = paddle.nn.AdaptiveMaxPool2d(
# output_size=[3, 3], data_format="NHWC")
#out_4 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool2d(
output_size=[None, 3])
out_5 = adaptive_max_pool(x=x)
assert np.allclose(out_1.numpy(), self.res_1_np)
assert np.allclose(out_2.numpy(), self.res_2_np)
assert np.allclose(out_3.numpy(), self.res_3_np)
#assert np.allclose(out_4.numpy(), self.res_4_np)
assert np.allclose(out_5.numpy(), self.res_5_np)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
def adaptive_start_index(index, input_size, output_size):
return int(np.floor(index * input_size / output_size))
def adaptive_end_index(index, input_size, output_size):
return int(np.ceil((index + 1) * input_size / output_size))
def adaptive_pool3d_forward(x,
output_size,
adaptive=True,
data_format='NCDHW',
pool_type='max'):
N = x.shape[0]
C, D, H, W = [x.shape[1], x.shape[2], x.shape[3], x.shape[4]] \
if data_format == 'NCDHW' else [x.shape[4], x.shape[1], x.shape[2],x.shape[3]]
if (isinstance(output_size, int) or output_size == None):
H_out = output_size
W_out = output_size
D_out = output_size
output_size = [D_out, H_out, W_out]
else:
D_out, H_out, W_out = output_size
if output_size[0] == None:
output_size[0] = D
D_out = D
if output_size[1] == None:
output_size[1] = H
H_out = H
if output_size[2] == None:
output_size[2] = W
W_out = W
out = np.zeros((N, C, D_out, H_out, W_out)) if data_format=='NCDHW' \
else np.zeros((N, D_out, H_out, W_out, C))
for k in range(D_out):
d_start = adaptive_start_index(k, D, output_size[0])
d_end = adaptive_end_index(k, D, output_size[0])
for i in range(H_out):
h_start = adaptive_start_index(i, H, output_size[1])
h_end = adaptive_end_index(i, H, output_size[1])
for j in range(W_out):
w_start = adaptive_start_index(j, W, output_size[2])
w_end = adaptive_end_index(j, W, output_size[2])
if data_format == 'NCDHW':
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:
w_end]
if pool_type == 'avg':
field_size = (d_end - d_start) * (h_end - h_start) * (
w_end - w_start)
out[:, :, k, i, j] = np.sum(x_masked,
axis=(2, 3, 4)) / field_size
elif pool_type == 'max':
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
elif data_format == 'NDHWC':
x_masked = x[:, d_start:d_end, h_start:h_end, w_start:
w_end, :]
if pool_type == 'avg':
field_size = (d_end - d_start) * (h_end - h_start) * (
w_end - w_start)
out[:, k, i, j, :] = np.sum(x_masked,
axis=(1, 2, 3)) / field_size
elif pool_type == 'max':
out[:, k, i, j, :] = np.max(x_masked, axis=(1, 2, 3))
return out
class TestAdaptiveMaxPool3dAPI(unittest.TestCase):
def setUp(self):
self.x_np = np.random.random([2, 3, 5, 7, 7]).astype("float32")
self.res_1_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[3, 3, 3], pool_type="max")
self.res_2_np = adaptive_pool3d_forward(
x=self.x_np, output_size=5, pool_type="max")
self.res_3_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[2, 3, 5], pool_type="max")
self.res_4_np = adaptive_pool3d_forward(
x=self.x_np,
output_size=[3, 3, 3],
pool_type="max",
data_format="NDHWC")
self.res_5_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[None, 3, None], pool_type="max")
def test_static_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x = paddle.data(name="x", shape=[2, 3, 5, 7, 7], dtype="float32")
out_1 = paddle.nn.functional.adaptive_max_pool3d(
x=x, output_size=[3, 3, 3])
out_2 = paddle.nn.functional.adaptive_max_pool3d(x=x, output_size=5)
out_3 = paddle.nn.functional.adaptive_max_pool3d(
x=x, output_size=[2, 3, 5])
#out_4 = paddle.nn.functional.adaptive_max_pool3d(
# x=x, output_size=[3, 3, 3], data_format="NDHWC")
out_5 = paddle.nn.functional.adaptive_max_pool3d(
x=x, output_size=[None, 3, None])
exe = paddle.static.Executor(place=place)
[res_1, res_2, res_3, res_5] = exe.run(
fluid.default_main_program(),
feed={"x": self.x_np},
fetch_list=[out_1, out_2, out_3, out_5])
assert np.allclose(res_1, self.res_1_np)
assert np.allclose(res_2, self.res_2_np)
assert np.allclose(res_3, self.res_3_np)
#assert np.allclose(res_4, self.res_4_np)
assert np.allclose(res_5, self.res_5_np)
def test_dynamic_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
x = paddle.to_variable(self.x_np)
out_1 = paddle.nn.functional.adaptive_max_pool3d(
x=x, output_size=[3, 3, 3])
out_2 = paddle.nn.functional.adaptive_max_pool3d(x=x, output_size=5)
out_3 = paddle.nn.functional.adaptive_max_pool3d(
x=x, output_size=[2, 3, 5])
#out_4 = paddle.nn.functional.adaptive_max_pool3d(
# x=x, output_size=[3, 3, 3], data_format="NDHWC")
out_5 = paddle.nn.functional.adaptive_max_pool3d(
x=x, output_size=[None, 3, None])
assert np.allclose(out_1.numpy(), self.res_1_np)
assert np.allclose(out_2.numpy(), self.res_2_np)
assert np.allclose(out_3.numpy(), self.res_3_np)
#assert np.allclose(out_4.numpy(), self.res_4_np)
assert np.allclose(out_5.numpy(), self.res_5_np)
class TestAdaptiveMaxPool3dClassAPI(unittest.TestCase):
def setUp(self):
self.x_np = np.random.random([2, 3, 5, 7, 7]).astype("float32")
self.res_1_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[3, 3, 3], pool_type="max")
self.res_2_np = adaptive_pool3d_forward(
x=self.x_np, output_size=5, pool_type="max")
self.res_3_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[2, 3, 5], pool_type="max")
# self.res_4_np = adaptive_pool3d_forward(
# x=self.x_np,
# output_size=[3, 3, 3],
# pool_type="max",
# data_format="NDHWC")
self.res_5_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[None, 3, None], pool_type="max")
def test_static_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x = paddle.data(name="x", shape=[2, 3, 5, 7, 7], dtype="float32")
adaptive_max_pool = paddle.nn.AdaptiveMaxPool3d(
output_size=[3, 3, 3])
out_1 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool3d(output_size=5)
out_2 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool3d(
output_size=[2, 3, 5])
out_3 = adaptive_max_pool(x=x)
# adaptive_max_pool = paddle.nn.AdaptiveMaxPool3d(
# output_size=[3, 3, 3], data_format="NDHWC")
# out_4 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool3d(
output_size=[None, 3, None])
out_5 = adaptive_max_pool(x=x)
exe = paddle.static.Executor(place=place)
[res_1, res_2, res_3, res_5] = exe.run(
fluid.default_main_program(),
feed={"x": self.x_np},
fetch_list=[out_1, out_2, out_3, out_5])
assert np.allclose(res_1, self.res_1_np)
assert np.allclose(res_2, self.res_2_np)
assert np.allclose(res_3, self.res_3_np)
# assert np.allclose(res_4, self.res_4_np)
assert np.allclose(res_5, self.res_5_np)
def test_dynamic_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
x = paddle.to_variable(self.x_np)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool3d(
output_size=[3, 3, 3])
out_1 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool3d(output_size=5)
out_2 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool3d(
output_size=[2, 3, 5])
out_3 = adaptive_max_pool(x=x)
# adaptive_max_pool = paddle.nn.AdaptiveMaxPool3d(
# output_size=[3, 3, 3], data_format="NDHWC")
# out_4 = adaptive_max_pool(x=x)
adaptive_max_pool = paddle.nn.AdaptiveMaxPool3d(
output_size=[None, 3, None])
out_5 = adaptive_max_pool(x=x)
assert np.allclose(out_1.numpy(), self.res_1_np)
assert np.allclose(out_2.numpy(), self.res_2_np)
assert np.allclose(out_3.numpy(), self.res_3_np)
# assert np.allclose(out_4.numpy(), self.res_4_np)
assert np.allclose(out_5.numpy(), self.res_5_np)
if __name__ == '__main__':
unittest.main()
......@@ -85,10 +85,35 @@ class TestBatchNorm(unittest.TestCase):
y = bn(fluid.dygraph.to_variable(x))
return y.numpy()
def compute_v3(x, is_test, trainable_statistics):
with fluid.dygraph.guard(p):
bn = fluid.dygraph.BatchNorm(
shape[1],
is_test=is_test,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.0),
trainable=False),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0),
trainable=False),
trainable_statistics=trainable_statistics)
y = bn(fluid.dygraph.to_variable(x))
return y.numpy()
def compute_v4(x):
with fluid.dygraph.guard(p):
bn = paddle.nn.BatchNorm2d(
shape[1], weight_attr=False, bias_attr=False)
y = bn(fluid.dygraph.to_variable(x))
return y.numpy()
x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x, False, False)
y2 = compute_v2(x)
y3 = compute_v3(x, False, False)
y4 = compute_v4(x)
self.assertTrue(np.allclose(y1, y2))
self.assertTrue(np.allclose(y3, y4))
def test_static(self):
places = [fluid.CPUPlace()]
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import Parameter
import numpy as np
......@@ -44,10 +45,10 @@ class InplaceTestBase(unittest.TestCase):
def build_program_and_scope(self):
self.place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
paddle.manual_seed(1)
paddle.framework.random._manual_program_seed(1)
startup_program = fluid.Program()
main_program = fluid.Program()
startup_program.random_seed = 1
main_program.random_seed = 1
scope = fluid.Scope()
with fluid.program_guard(main_program, startup_program):
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from test_imperative_base import new_program_scope
......@@ -29,8 +30,8 @@ class TestCompiledProgram(unittest.TestCase):
self.label = np.random.randint(
low=0, high=10, size=[16, 1], dtype=np.int64)
with new_program_scope():
fluid.default_startup_program().random_seed = self.seed
fluid.default_main_program().random_seed = self.seed
paddle.manual_seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -46,8 +47,8 @@ class TestCompiledProgram(unittest.TestCase):
def test_compiled_program_base(self):
with new_program_scope():
fluid.default_startup_program().random_seed = self.seed
fluid.default_main_program().random_seed = self.seed
paddle.manual_seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -64,8 +65,8 @@ class TestCompiledProgram(unittest.TestCase):
def test_compiled_program_with_data_parallel(self):
with new_program_scope():
fluid.default_startup_program().random_seed = self.seed
fluid.default_main_program().random_seed = self.seed
paddle.manual_seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
......
......@@ -34,10 +34,10 @@ def random_reader():
def simple_fc_net(places, use_legacy_py_reader, use_double_buffer):
paddle.manual_seed(1)
paddle.framework.random._manual_program_seed(1)
startup_prog = fluid.Program()
main_prog = fluid.Program()
startup_prog.random_seed = 1
main_prog.random_seed = 1
with fluid.unique_name.guard():
with fluid.program_guard(main_prog, startup_prog):
......
......@@ -38,9 +38,10 @@ class TestDirectory(unittest.TestCase):
'paddle.enable_static', 'paddle.disable_static',
'paddle.in_dynamic_mode', 'paddle.to_variable', 'paddle.grad',
'paddle.no_grad', 'paddle.save', 'paddle.load',
'paddle.static.save', 'paddle.static.load', 'paddle.ParallelEnv',
'paddle.prepare_context', 'paddle.DataParallel', 'paddle.jit',
'paddle.jit.TracedLayer', 'paddle.jit.to_static',
'paddle.static.save', 'paddle.static.load',
'paddle.distributed.ParallelEnv',
'paddle.distributed.prepare_context', 'paddle.DataParallel',
'paddle.jit', 'paddle.jit.TracedLayer', 'paddle.jit.to_static',
'paddle.jit.ProgramTranslator', 'paddle.jit.TranslatedLayer',
'paddle.jit.save', 'paddle.jit.load', 'paddle.jit.SaveLoadConfig',
'paddle.NoamDecay', 'paddle.PiecewiseDecay',
......
......@@ -23,8 +23,11 @@ import subprocess
import six
import argparse
import pickle
import random
import numpy as np
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle.fluid.dygraph as dygraph
......@@ -382,22 +385,22 @@ class TestParallelDyGraphRunnerBase(object):
raise NotImplementedError(
"train_one_loop should be implemented by the child classes.")
def _get_data(self, batch, args):
if args.update_method != "local":
new_batch = []
for offset, item in enumerate(batch):
if offset % 2 == args.trainer_id:
new_batch.append(item)
return new_batch
else:
return batch
def run_trainer(self, args):
seed = 90
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id)
def _get_data(batch):
if args.update_method != "local":
new_batch = []
for offset, item in enumerate(batch):
if offset % 2 == args.trainer_id:
new_batch.append(item)
return new_batch
else:
return batch
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
......@@ -422,7 +425,7 @@ class TestParallelDyGraphRunnerBase(object):
out_losses = []
print_to_err(type(self).__name__, "begin to run dygraph training")
for step_id, data in enumerate(train_reader()):
data = _get_data(data)
data = self._get_data(data, args)
if step_id == RUN_STEP:
break
loss = self.run_one_loop(model, opt, data)
......@@ -444,6 +447,47 @@ class TestParallelDyGraphRunnerBase(object):
model.clear_gradients()
print_to_out(out_losses)
def run_trainer_with_spawn(self, args):
# 1. enable dygraph
paddle.disable_static()
# 2. init seed
seed = 90
paddle.static.default_startup_program().random_seed = seed
paddle.static.default_main_program().random_seed = seed
np.random.seed(seed)
random.seed(seed)
# get trainer id
args.trainer_id = paddle.distributed.get_rank()
# 3. init parallel env
if args.update_method == "nccl2":
paddle.distributed.init_parallel_env()
# 4. train model
model, train_reader, opt = self.get_model()
if args.update_method == "nccl2":
model = paddle.DataParallel(model)
out_losses = []
for step_id, data in enumerate(train_reader()):
data = self._get_data(data, args)
if step_id == RUN_STEP:
break
loss = self.run_one_loop(model, opt, data)
out_losses.append(loss.numpy())
if args.update_method == "nccl2":
loss = model.scale_loss(loss)
loss.backward()
if args.update_method == "nccl2":
model.apply_collective_grads()
opt.minimize(loss)
model.clear_gradients()
return out_losses
def runtime_main(test_class):
parser = argparse.ArgumentParser(description='Run dist test.')
......
......@@ -27,6 +27,8 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
SEED = 123123111
class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
......@@ -105,12 +107,11 @@ class MNIST(fluid.dygraph.Layer):
class TestDygraphMultiForward(unittest.TestCase):
def test_mnist_forward_float32(self):
seed = 90
epoch_num = 1
with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
with fluid.dygraph.guard():
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
mnist = MNIST()
sgd = SGDOptimizer(
learning_rate=1e-3, parameter_list=mnist.parameters())
......@@ -142,9 +143,8 @@ class TestDygraphMultiForward(unittest.TestCase):
dy_param_init_value[param.name] = param.numpy()
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
......
......@@ -18,6 +18,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
......@@ -465,9 +466,9 @@ class PaddingRNNTestBase(unittest.TestCase):
pass
def _prepare_program(self, config, parallel=True):
paddle.manual_seed(config.random_seed)
self.main_program = fluid.Program()
self.startup_program = fluid.Program()
self.startup_program.random_seed = config.random_seed
with fluid.program_guard(self.main_program, self.startup_program):
with fluid.unique_name.guard():
res_vars = lm_model(
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
import six
import unittest
......@@ -37,13 +38,13 @@ class TestEmbeddingIdStopGradientBase(unittest.TestCase):
self.assertTrue(np.array_equal(grad_value1, grad_value2))
def run_program(self, place, stop_gradient=False):
np.random.seed(1)
paddle.manual_seed(1)
paddle.framework.random._manual_program_seed(1)
startup_program = fluid.Program()
main_program = fluid.Program()
np.random.seed(1)
startup_program.random_seed = 1
main_program.random_seed = 1
scope = fluid.Scope()
with fluid.program_guard(main_program, startup_program):
with fluid.scope_guard(scope):
......
......@@ -102,8 +102,23 @@ class TestExpandAsOpRank4(OpTest):
self.check_grad(['X'], 'Out')
# Test dygraph API
class TestExpandAsDygraphAPI(unittest.TestCase):
def test_api(self):
import paddle
paddle.disable_static()
np_data_x = np.array([1, 2, 3]).astype('int32')
np_data_y = np.array([1, 2, 3, 1, 2, 3]).astype('int32')
data_x = paddle.to_tensor(np_data_x)
data_y = paddle.to_tensor(np_data_y)
out = fluid.layers.expand_as(data_x, data_y)
np_out = out.numpy()
assert np.array_equal(np_out, np.tile(np_data_x, (2)))
paddle.enable_static()
# Test python API
class TestExpandAPI(unittest.TestCase):
class TestExpandAsAPI(unittest.TestCase):
def test_api(self):
input1 = np.random.random([12, 14]).astype("float32")
input2 = np.random.random([48, 14]).astype("float32")
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import unittest
import paddle
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
......@@ -135,31 +136,32 @@ class TestFCOpWithPadding(TestFCOp):
class TestFcOp_NumFlattenDims_NegOne(unittest.TestCase):
def test_api(self):
startup_program = Program()
main_program = Program()
startup_program.random_seed = SEED
main_program.random_seed = SEED
with program_guard(main_program, startup_program):
input = np.random.random([2, 2, 25]).astype("float32")
x = fluid.layers.data(
name="x",
shape=[2, 2, 25],
append_batch_size=False,
dtype="float32")
out_1 = fluid.layers.fc(input=x, size=1, num_flatten_dims=-1)
out_2 = fluid.layers.fc(input=x, size=1, num_flatten_dims=2)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = fluid.Executor(place=place)
exe.run(startup_program)
res_1, res_2 = exe.run(main_program,
feed={"x": input},
fetch_list=[out_1, out_2])
assert np.array_equal(res_1, res_2)
def run_program(num_flatten_dims):
paddle.manual_seed(SEED)
startup_program = Program()
main_program = Program()
with program_guard(main_program, startup_program):
input = np.random.random([2, 2, 25]).astype("float32")
x = fluid.layers.data(
name="x",
shape=[2, 2, 25],
append_batch_size=False,
dtype="float32")
out = fluid.layers.fc(input=x,
size=1,
num_flatten_dims=num_flatten_dims)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = fluid.Executor(place=place)
exe.run(startup_program)
out = exe.run(main_program, feed={"x": input}, fetch_list=[out])
res_1 = run_program(-1)
res_2 = run_program(2)
self.assertTrue(np.array_equal(res_1, res_2))
class TestFCOpError(unittest.TestCase):
......