Commit f4cc5881 authored by frankwhzhang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into bpr

add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
add_subdirectory(imperative)
add_subdirectory(operators)
add_subdirectory(string)
add_subdirectory(recordio)
......
......@@ -33,11 +33,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) {
if (name == use_slots_[i]) {
if (use_slots_is_dense_[i]) {
feed_vec_[i] = MixTensor(var->GetMutable<Tensor>());
} else {
feed_vec_[i] = MixTensor(var->GetMutable<LoDTensor>());
}
feed_vec_[i] = var->GetMutable<LoDTensor>();
}
}
}
......@@ -301,6 +297,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s",
str);
if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]);
if ((*instance)[idx].GetType()[0] == 'f') { // float
......@@ -337,6 +334,7 @@ void MultiSlotDataFeed::AddInstanceToInsVec(
(*ins_vec)[i].InitOffset();
}
}
for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].AddIns(instance[i]);
}
......@@ -348,36 +346,25 @@ void MultiSlotDataFeed::PutToFeedVec(
const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset();
int total_instance = static_cast<int>(offset.back());
if (type[0] == 'f') { // float
const auto& feasign = ins_vec[i].GetFloatData();
if (feed_vec_[i].IsDense()) {
int size_in_each_batch = total_instance / batch_size_;
float* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<float>(
{batch_size_, size_in_each_batch}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else {
float* tensor_ptr = feed_vec_[i].GetLoDTensor()->mutable_data<float>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
LoD data_lod{offset};
feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
}
float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64
// no uint64_t type in PaddlePaddle
const auto& feasign = ins_vec[i].GetUint64Data();
if (feed_vec_[i].IsDense()) {
int size_in_each_batch = total_instance / batch_size_;
int64_t* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<int64_t>(
{batch_size_, size_in_each_batch}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
} else {
int64_t* tensor_ptr =
feed_vec_[i].GetLoDTensor()->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
LoD data_lod{offset};
feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
}
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
}
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
int dim = total_instance / batch_size_;
feed_vec_[i]->Resize({batch_size_, dim});
}
}
}
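This hunk is the core of the change: every slot, dense or sparse, now feeds a plain LoDTensor. The data is first written flat as {total_instance, 1} together with its LoD, and dense slots are then simply reshaped to {batch_size, dim}. A minimal standalone sketch of that flow (an illustrative helper, not part of the commit):
#include <cstring>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/place.h"
// Fill one float slot the way PutToFeedVec now does: write the flat data,
// attach the LoD, then reshape only if the slot is dense.
void FillSlot(paddle::framework::LoDTensor* tensor,
              const std::vector<float>& feasign,
              const std::vector<size_t>& offset, int batch_size,
              bool is_dense) {
  int total_instance = static_cast<int>(offset.back());
  float* ptr = tensor->mutable_data<float>({total_instance, 1},
                                           paddle::platform::CPUPlace());
  std::memcpy(ptr, feasign.data(), total_instance * sizeof(float));
  paddle::framework::LoD lod{offset};
  tensor->set_lod(lod);
  if (is_dense) {
    int dim = total_instance / batch_size;
    tensor->Resize({batch_size, dim});  // dense slots drop the {N, 1} shape
  }
}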
......
......@@ -30,35 +30,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Pack Tensor type and LoDTensor type into MixTensor type, in order
// to record either Tensor or LoDTensor information at the same time.
class MixTensor {
public:
MixTensor() {}
explicit MixTensor(LoDTensor* lodtensor) {
is_dense_ = false;
lodtensor_ = lodtensor;
}
explicit MixTensor(Tensor* tensor) {
is_dense_ = true;
tensor_ = tensor;
}
bool IsDense() { return is_dense_; }
LoDTensor* GetLoDTensor() {
PADDLE_ENFORCE(!is_dense_, "Let a dense var return a LoDTensor ptr.");
return lodtensor_;
}
Tensor* GetTensor() {
PADDLE_ENFORCE(is_dense_, "Let a sparse var return a Tensor ptr.");
return tensor_;
}
private:
bool is_dense_;
LoDTensor* lodtensor_;
Tensor* tensor_;
};
// DataFeed is the base virtual class for all other DataFeeds.
// It is used to read files and parse the data for the subsequent trainer.
// Example:
......@@ -133,7 +104,7 @@ class DataFeed {
use_slots_index_; // -1: not used; >=0: the index of use_slots_
// The data read by DataFeed will be stored here
std::vector<MixTensor> feed_vec_;
std::vector<LoDTensor*> feed_vec_;
// the batch size defined by user
int default_batch_size_;
......
......@@ -152,19 +152,13 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
const auto& multi_slot_desc = data_feed_desc.multi_slot_desc();
std::map<std::string, const paddle::framework::LoDTensor*>
lodtensor_targets;
std::map<std::string, const paddle::framework::Tensor*> tensor_targets;
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
const auto& slot = multi_slot_desc.slots(i);
if (slot.is_used()) {
const auto& name = slot.name();
readers[idx]->AddFeedVar(scope->Var(name), name);
if (slot.is_dense()) {
tensor_targets[name] =
&scope->FindVar(name)->Get<paddle::framework::Tensor>();
} else {
lodtensor_targets[name] =
&scope->FindVar(name)->Get<paddle::framework::LoDTensor>();
}
lodtensor_targets[name] =
&scope->FindVar(name)->Get<paddle::framework::LoDTensor>();
}
}
readers[idx]->Start();
......@@ -175,8 +169,9 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
if (!slot.is_used()) {
continue;
}
const paddle::framework::LoDTensor* tens =
lodtensor_targets[slot.name()];
if (slot.is_dense()) { // dense branch
const paddle::framework::Tensor* tens = tensor_targets[slot.name()];
if (slot.type() == "uint64") {
const int64_t* data = tens->data<int64_t>();
int batch_size = tens->dims()[0];
......@@ -202,8 +197,6 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
PADDLE_THROW("Error type in proto file.");
}
} else { // sparse branch
const paddle::framework::LoDTensor* tens =
lodtensor_targets[slot.name()];
if (slot.type() == "uint64") {
const int64_t* data = tens->data<int64_t>();
for (size_t i = 0; i < tens->NumElements(); ++i) {
......
......@@ -97,7 +97,7 @@ void ExecutorThreadWorker::SetDevice() {
static unsigned concurrency_cap = std::thread::hardware_concurrency();
int thread_id = this->thread_id_;
if (thread_id < concurrency_cap) {
if (static_cast<unsigned>(thread_id) < concurrency_cap) {
unsigned proc = thread_id;
cpu_set_t mask;
......
......@@ -16,7 +16,9 @@ limitations under the License. */
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
......@@ -53,5 +55,12 @@ LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
return tensor;
}
LoDTensor& GetVariableTensor(const Scope& scope, const std::string& var_name) {
Variable* var = scope.FindVar(var_name);
PADDLE_ENFORCE(var, "%s no in scope", var_name);
PADDLE_ENFORCE(var->IsType<LoDTensor>(), "Only support lod tensor now.");
return *var->GetMutable<LoDTensor>();
}
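A minimal caller-side sketch for the new GetVariableTensor helper (the variable name is illustrative): it returns a mutable reference to the LoDTensor held by a named variable in a scope.
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
void GetVariableTensorExample() {
  paddle::framework::Scope scope;
  // The variable must already exist and hold a LoDTensor, otherwise the
  // PADDLE_ENFORCE checks above fire.
  scope.Var("fc_0.w_0")->GetMutable<paddle::framework::LoDTensor>();
  paddle::framework::LoDTensor& t =
      paddle::framework::GetVariableTensor(scope, "fc_0.w_0");
  t.Resize({2, 3});  // mutate the tensor in place via the returned reference
}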
} // namespace framework
} // namespace paddle
......@@ -27,5 +27,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input,
LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
size_t index);
LoDTensor& GetVariableTensor(const Scope& scope, const std::string& var_name);
} // namespace framework
} // namespace paddle
......@@ -38,9 +38,8 @@ void CheckProgram(const ProgramDesc &program) {
switch (role_id) {
case _INT(OpRole::kForward):
if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
LOG(ERROR)
<< "Cannot add backward operator before forward operator %s."
<< op->Type();
LOG(ERROR) << "Cannot add backward operator before forward operator "
<< op->Type();
}
break;
case _INT(OpRole::kBackward):
......
cc_library(layer SRCS layer.cc DEPS proto_desc operator)
cc_library(tracer SRCS tracer.cc DEPS proto_desc)
cc_library(engine SRCS engine.cc)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/engine.h"
#include <mutex> // NOLINT
#include <vector>
#include "glog/logging.h"
namespace paddle {
namespace imperative {
static std::once_flag init_engine;
static Engine* engine;
class DummyEngine : public Engine {
public:
void Enqueue(Runnable* runnable) override {
queued_runnables_.push_back(runnable);
}
size_t Size() const override { return queued_runnables_.size(); }
void Sync() override {
for (Runnable* l : queued_runnables_) {
LOG(INFO) << "running " << reinterpret_cast<void*>(l);
}
queued_runnables_.clear();
}
private:
std::vector<Runnable*> queued_runnables_;
};
Engine* GetEngine() {
std::call_once(init_engine, []() { engine = new DummyEngine(); });
return engine;
}
} // namespace imperative
} // namespace paddle
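A usage sketch for the engine above: obtain the process-wide singleton, enqueue a Runnable, then Sync. With the DummyEngine registered here, Sync only logs each queued runnable and clears the queue.
#include "paddle/fluid/imperative/engine.h"
void EngineExample() {
  paddle::imperative::Runnable task;
  paddle::imperative::Engine* eng = paddle::imperative::GetEngine();
  eng->Enqueue(&task);  // queued; nothing runs yet, Size() == 1
  eng->Sync();          // DummyEngine logs the runnable and clears the queue
}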
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <cstdint>
namespace paddle {
namespace imperative {
struct Runnable {};
class Engine {
public:
virtual ~Engine() {}
virtual void Enqueue(Runnable* runnable) = 0;
virtual size_t Size() const = 0;
virtual void Sync() = 0;
};
Engine* GetEngine();
} // namespace imperative
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/layer.h"
#include <deque>
#include <limits>
#include <map>
#include <random>
#include <utility>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace imperative {
using framework::Variable;
void AddTo(Variable* src, Variable* dst) {
framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(), "%lld vs %lld",
dst_tensor->numel(), src_tensor->numel());
float* dst_data = dst_tensor->mutable_data<float>(platform::CPUPlace());
const float* src_data = src_tensor->data<float>();
for (int64_t i = 0; i < src_tensor->numel(); ++i) {
dst_data[i] += src_data[i];
}
}
class Autograd {
public:
explicit Autograd(framework::Scope* scope) : scope_(scope) {}
void RunBackward(VarBase* var) {
PADDLE_ENFORCE(var->pre_op_->op_desc_);
// TODO(panyx0718): Only create for vars that "require_grad"
(*var->pre_op_->output_vars_)[var->pre_op_out_idx_]->grads_ = var->grads_;
std::deque<OpBase*> ready;
ready.push_back(var->pre_op_);
std::map<OpBase*, int> dep_counts = ComputeDepCounts(var->pre_op_);
while (!ready.empty()) {
OpBase* ready_op = ready.front();
ready.pop_front();
std::vector<Variable*> input_grads = ready_op->ApplyGrad(scope_);
for (size_t i = 0; i < input_grads.size(); ++i) {
if (!input_grads[i]) continue;
OpBase* pre_op = ready_op->pre_ops_->at(i);
if (!pre_op) continue;
dep_counts[pre_op] -= 1;
PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
bool pre_op_ready = dep_counts[pre_op] == 0;
if (pre_op_ready) {
ready.push_back(pre_op);
}
}
}
}
private:
std::map<OpBase*, int> ComputeDepCounts(OpBase* op) {
std::map<OpBase*, int> ret;
std::deque<OpBase*> queue;
queue.push_back(op);
std::unordered_set<OpBase*> visited;
visited.insert(op);
while (!queue.empty()) {
OpBase* candidate = queue.front();
queue.pop_front();
for (OpBase* pre_op : *(candidate->pre_ops_)) {
if (!pre_op) continue;
if (visited.find(pre_op) == visited.end()) {
visited.insert(pre_op);
queue.push_back(pre_op);
}
ret[pre_op] += 1;
}
}
return ret;
}
framework::Scope* scope_;
};
framework::Variable* CreateVariable(const std::string& name,
const framework::DDim& dim, float val,
framework::Scope* scope,
bool random_name = true) {
std::string varname = name;
if (random_name) {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist6(
1, std::numeric_limits<int>::max());
int id = dist6(rng);
varname = string::Sprintf("%s@%d", varname, id);
}
VLOG(3) << "creating var " << varname;
framework::Variable* var = scope->Var(varname);
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
float* data = tensor->mutable_data<float>(dim, platform::CPUPlace());
std::fill(data, data + tensor->numel(), val);
return var;
}
framework::LoDTensor& VarBase::Grad() {
VLOG(3) << "get var grad " << var_desc_->Name();
return *grads_->GetMutable<framework::LoDTensor>();
}
void VarBase::ApplyGrad(framework::Scope* scope, Variable* grad) {
VLOG(3) << "apply var grad " << var_desc_->Name() << " "
<< grad->Get<framework::LoDTensor>().data<float>()[0];
if (!grads_) {
grads_ =
CreateVariable(string::Sprintf("%s@IGrad", var_desc_->Name()),
var_->Get<framework::LoDTensor>().dims(), 0.0, scope);
}
AddTo(grad, grads_);
VLOG(3) << "grad_ after apply var grad " << var_desc_->Name() << " "
<< grads_->Get<framework::LoDTensor>().data<float>()[0];
}
std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
VLOG(3) << "op grad " << grad_op_desc_->Type();
for (const std::string& grad_invar : grad_op_desc_->InputArgumentNames()) {
if (grad_to_var_->find(grad_invar) == grad_to_var_->end()) {
// grad op inputs can be forward inputs, so not in grad_to_var.
continue;
}
VLOG(3) << "op grad in var " << grad_invar;
block_->FindRecursiveOrCreateVar(grad_invar);
framework::Variable* var = scope->Var(grad_invar);
const std::string& invar = grad_to_var_->at(grad_invar);
for (VarBase* varbase : *output_vars_) {
// Use the accumulated grads_ by sharing the input with grads_.
if (varbase->var_desc_->Name() == invar) {
var->GetMutable<framework::LoDTensor>()->ShareDataWith(
varbase->grads_->Get<framework::LoDTensor>());
break;
}
}
}
for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
VLOG(3) << "grad outvar " << outvar;
block_->FindRecursiveOrCreateVar(outvar);
framework::Variable* var = scope->Var(outvar);
if (!var->IsInitialized()) {
framework::VarDesc* var_desc = block_->FindVar(outvar);
if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>();
} else {
LOG(ERROR) << "tracer doesn't support yet";
}
}
}
grad_op_desc_->InferShape(*block_);
grad_op_desc_->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc_);
opbase->Run(*scope, platform::CPUPlace());
// `ret` matches exactly with `input_vars_` of forward op.
std::vector<Variable*> ret;
for (size_t i = 0; i < input_vars_->size(); ++i) {
bool found = false;
for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
Variable* var = scope->FindVar(outvar);
VarBase* origin_var = (*input_vars_)[i];
std::string orig_var = grad_to_var_->at(outvar);
PADDLE_ENFORCE(origin_var->var_desc_->Name() == orig_var);
VLOG(3) << "apply grad " << outvar << " with origin " << orig_var;
origin_var->ApplyGrad(scope, var);
found = true;
ret.push_back(var);
// TODO(panyx0718): There might be another outvar with the same name.
// In that case, it doesn't matter the first one or the second one is
// used.
break;
}
if (!found) {
ret.push_back(nullptr);
}
}
return ret;
}
void VarBase::RunBackward(framework::Scope* scope) {
grads_ = CreateVariable(framework::GradVarName(var_desc_->Name()),
var_->Get<framework::LoDTensor>().dims(), 1.0, scope,
false);
if (!pre_op_) return;
Autograd(scope).RunBackward(this);
}
} // namespace imperative
} // namespace paddle
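Autograd above is a reverse-topological walk: ComputeDepCounts counts, for each producing op, how many of its consumers still have to run, and RunBackward moves an op onto the ready queue only once that count drops to zero. A standalone sketch of the same scheme with plain std containers (illustrative, not Paddle types):
#include <deque>
#include <map>
#include <set>
#include <vector>
struct Node {
  std::vector<Node*> pre;  // producers this node depends on
};
// Mirrors Autograd: count pending consumers per producer (ComputeDepCounts),
// then pop nodes whose count has dropped to zero (RunBackward's ready queue).
inline std::vector<Node*> WalkBackward(Node* start) {
  std::map<Node*, int> dep;
  std::set<Node*> visited{start};
  std::deque<Node*> queue{start};
  while (!queue.empty()) {
    Node* cur = queue.front();
    queue.pop_front();
    for (Node* p : cur->pre) {
      if (visited.insert(p).second) queue.push_back(p);
      dep[p] += 1;
    }
  }
  std::vector<Node*> order;
  std::deque<Node*> ready{start};
  while (!ready.empty()) {
    Node* cur = ready.front();
    ready.pop_front();
    order.push_back(cur);  // "apply the grad op" here
    for (Node* p : cur->pre) {
      if (--dep[p] == 0) ready.push_back(p);
    }
  }
  return order;
}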
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace imperative {
class OpBase;
class VarBase {
public:
VarBase()
: pre_op_(nullptr),
pre_op_out_idx_(-1),
var_desc_(nullptr),
var_(nullptr),
grads_(nullptr) {}
virtual ~VarBase() {}
void ApplyGrad(framework::Scope* scope, framework::Variable* grad);
void RunBackward(framework::Scope* scope);
framework::LoDTensor& Grad();
OpBase* pre_op_;
int pre_op_out_idx_;
framework::VarDesc* var_desc_;
framework::Variable* var_;
framework::Variable* grads_;
};
class OpBase {
public:
OpBase()
: input_vars_(new std::vector<VarBase*>()),
output_vars_(new std::vector<VarBase*>()),
pre_ops_(new std::vector<OpBase*>()),
pre_ops_out_idx_(new std::vector<int>()),
op_desc_(nullptr),
grad_op_desc_(nullptr) {}
virtual ~OpBase() {
delete input_vars_;
delete output_vars_;
delete pre_ops_;
delete pre_ops_out_idx_;
if (grad_op_desc_) delete grad_op_desc_;
if (grad_to_var_) delete grad_to_var_;
}
std::vector<framework::Variable*> ApplyGrad(framework::Scope* scope);
std::vector<VarBase*>* input_vars_;
std::vector<VarBase*>* output_vars_;
std::vector<OpBase*>* pre_ops_;
std::vector<int>* pre_ops_out_idx_;
framework::OpDesc* op_desc_;
framework::OpDesc* grad_op_desc_;
std::unordered_map<std::string, std::string>* grad_to_var_;
framework::BlockDesc* block_;
};
class Layer {
public:
virtual ~Layer() {}
virtual std::vector<VarBase> Forward(const std::vector<VarBase>& inputs) {
std::vector<VarBase> vars;
return vars;
}
virtual void Backward() { LOG(ERROR) << "To support customized backward"; }
};
} // namespace imperative
} // namespace paddle
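A minimal sketch of a user-defined layer against the interface above (illustrative): the base Forward returns an empty vector and Backward only logs, so subclasses override Forward.
#include <vector>
#include "paddle/fluid/imperative/layer.h"
class IdentityLayer : public paddle::imperative::Layer {
 public:
  std::vector<paddle::imperative::VarBase> Forward(
      const std::vector<paddle::imperative::VarBase>& inputs) override {
    return inputs;  // pass the inputs through unchanged
  }
};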
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/tracer.h"
namespace paddle {
namespace imperative {} // namespace imperative
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/layer.h"
namespace paddle {
namespace imperative {
void CreateGradOp(const framework::OpDesc& op_desc,
const std::unordered_set<std::string>& no_grad_set,
const std::vector<framework::BlockDesc*>& grad_sub_block,
framework::OpDesc** grad_op_desc,
std::unordered_map<std::string, std::string>* grad_to_var) {
std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
framework::OpInfoMap::Instance()
.Get(op_desc.Type())
.GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
PADDLE_ENFORCE(grad_op_descs.size() == 1, "Only support 1 grad op now.");
// TODO(panyx0718): Leak?
*grad_op_desc = grad_op_descs[0].release();
}
class Tracer {
public:
explicit Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
root_scope_ = new framework::Scope();
scopes_[root_block_] = root_scope_;
}
virtual ~Tracer() { delete root_scope_; }
void Trace(OpBase* op, const std::vector<VarBase*>& inputs,
const std::vector<VarBase*>& outputs,
framework::BlockDesc* block) {
framework::Scope* scope = GetScope(block);
framework::OpDesc* op_desc = op->op_desc_;
VLOG(3) << "tracer tracing " << op_desc->Type();
op_desc->InferShape(*block);
op_desc->InferVarType(block);
std::unique_ptr<framework::OperatorBase> op_base =
framework::OpRegistry::CreateOp(*op_desc);
*op->input_vars_ = inputs;
for (VarBase* input : inputs) {
const std::string vname = input->var_desc_->Name();
framework::Variable* var = scope->Var(vname);
input->var_ = var;
if (!var->IsInitialized()) {
framework::VarDesc* var_desc = block->FindVar(vname);
if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>();
} else {
LOG(ERROR) << "tracer doesn't support yet";
}
}
if (input->pre_op_) {
op->pre_ops_->push_back(input->pre_op_);
op->pre_ops_out_idx_->push_back(input->pre_op_out_idx_);
} else {
op->pre_ops_->push_back(nullptr);
}
}
*op->output_vars_ = outputs;
for (size_t i = 0; i < outputs.size(); ++i) {
const std::string vname = outputs[i]->var_desc_->Name();
framework::Variable* var = scope->Var(vname);
if (!var->IsInitialized()) {
framework::VarDesc* var_desc = block->FindVar(vname);
if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>();
} else {
LOG(ERROR) << "tracer doesn't support yet";
}
}
outputs[i]->var_ = var;
outputs[i]->pre_op_ = op;
outputs[i]->pre_op_out_idx_ = i;
}
op_base->Run(*scope, platform::CPUPlace());
framework::OpDesc* grad_op_desc;
auto grad_to_var = new std::unordered_map<std::string, std::string>();
CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var);
op->grad_op_desc_ = grad_op_desc;
op->grad_to_var_ = grad_to_var;
op->block_ = block;
}
framework::Scope* GetScope(framework::BlockDesc* block) {
if (scopes_.find(block) != scopes_.end()) {
return scopes_.at(block);
}
framework::BlockDesc* parent_block = block->ParentBlock();
PADDLE_ENFORCE(scopes_.find(parent_block) != scopes_.end());
framework::Scope* scope = &scopes_[parent_block]->NewScope();
scopes_[block] = scope;
return scope;
}
private:
std::map<framework::BlockDesc*, framework::Scope*> scopes_;
framework::BlockDesc* root_block_;
framework::Scope* root_scope_;
};
} // namespace imperative
} // namespace paddle
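A construction sketch for the tracer above (illustrative): the root block comes from a ProgramDesc owned by the caller, the tracer creates and owns the root scope, and GetScope falls back to the parent block's scope for nested blocks.
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/imperative/tracer.h"
void TracerExample() {
  paddle::framework::ProgramDesc program;  // block 0 exists by default
  paddle::framework::BlockDesc* root = program.MutableBlock(0);
  paddle::imperative::Tracer tracer(root);
  paddle::framework::Scope* scope = tracer.GetScope(root);  // the root scope
  (void)scope;
}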
......@@ -103,6 +103,7 @@ struct Argument {
// Model specified with program and parameters files.
DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
// The overall graph to work on.
DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
......
......@@ -46,7 +46,7 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
argument->model_params_path_valid()) {
auto program =
LoadModel(argument->model_program_path(), argument->model_params_path(),
argument->scope_ptr(), place);
argument->scope_ptr(), place, argument->model_from_memory());
argument->SetMainProgram(program.release());
} else {
PADDLE_THROW(
......@@ -68,9 +68,14 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
const std::string &program_path, const std::string &params_path,
framework::Scope *scope, const platform::Place &place) {
framework::Scope *scope, const platform::Place &place,
bool model_from_memory) {
framework::Executor exe(place);
return Load(&exe, scope, program_path, params_path);
if (!model_from_memory) {
return Load(&exe, scope, program_path, params_path);
} else {
return LoadFromMemory(&exe, scope, program_path, params_path);
}
}
std::string IrGraphBuildPass::repr() const { return "ir-graph-build-pass"; }
......
......@@ -24,7 +24,7 @@ namespace inference {
namespace analysis {
/*
* Load program and parameter to memory from the disk.
* Load program and parameter to memory from the disk or directly from memory.
*/
class IrGraphBuildPass : public AnalysisPass {
public:
......@@ -38,7 +38,8 @@ class IrGraphBuildPass : public AnalysisPass {
const platform::Place &place);
std::unique_ptr<framework::ProgramDesc> LoadModel(
const std::string &program_path, const std::string &params_path,
framework::Scope *scope, const platform::Place &place);
framework::Scope *scope, const platform::Place &place,
bool model_from_memory);
std::string model_binary_str_;
};
......
......@@ -53,6 +53,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
use_tensorrt_ = other.use_tensorrt_;
tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
tensorrt_workspace_size_ = other.tensorrt_workspace_size_;
model_from_memory_ = other.model_from_memory_;
if (use_gpu) {
pass_builder_.reset(new GpuPassStrategy(
......@@ -80,6 +81,8 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) {
use_tensorrt_ = other.use_tensorrt_;
tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
tensorrt_workspace_size_ = other.tensorrt_workspace_size_;
model_from_memory_ = other.model_from_memory_;
pass_builder_ = std::move(other.pass_builder_);
}
......@@ -102,4 +105,13 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
pass_builder()->InsertPass(1, "tensorrt_subgraph_pass");
}
void contrib::AnalysisConfig::SetModelBuffer(const char *prog_buffer,
size_t prog_buffer_size,
const char *param_buffer,
size_t param_buffer_size) {
prog_file = std::string(prog_buffer, prog_buffer + prog_buffer_size);
param_file = std::string(param_buffer, param_buffer + param_buffer_size);
model_from_memory_ = true;
}
} // namespace paddle
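A caller-side sketch of SetModelBuffer (the include path and buffer contents are assumptions; typically the buffers hold the serialized __model__ and the combined parameter file):
#include <string>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
void ConfigFromMemory(const std::string& prog_buf,
                      const std::string& param_buf) {
  paddle::contrib::AnalysisConfig cfg;
  cfg.SetModelBuffer(prog_buf.data(), prog_buf.size(), param_buf.data(),
                     param_buf.size());
  // prog_file / param_file now carry the raw buffers and
  // cfg.model_from_memory() returns true.
}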
......@@ -308,6 +308,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
argument_.SetUseGPU(config_.use_gpu);
argument_.SetGPUDeviceId(config_.device);
argument_.SetModelFromMemory(config_.model_from_memory_);
// Analyze inference_program
if (!config_.model_dir.empty()) {
argument_.SetModelDir(config_.model_dir);
......@@ -448,20 +449,24 @@ bool AnalysisPredictor::LoadProgramDesc() {
return false;
}
std::string pb_content;
// Read binary
std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", filename);
fin.seekg(0, std::ios::end);
pb_content.resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(pb_content.at(0)), pb_content.size());
fin.close();
// Create ProgramDesc
framework::proto::ProgramDesc proto;
proto.ParseFromString(pb_content);
if (!config_.model_from_memory()) {
std::string pb_content;
// Read binary
std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin.is_open()), "Cannot open file %s",
filename);
fin.seekg(0, std::ios::end);
pb_content.resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(pb_content.at(0)), pb_content.size());
fin.close();
proto.ParseFromString(pb_content);
} else {
proto.ParseFromString(config_.prog_file);
}
inference_program_.reset(new framework::ProgramDesc(proto));
return true;
}
......@@ -469,6 +474,7 @@ bool AnalysisPredictor::LoadProgramDesc() {
bool AnalysisPredictor::LoadParameters() {
PADDLE_ENFORCE_NOT_NULL(inference_program_.get(),
"The inference program should be loaded first.");
const auto &global_block = inference_program_->MutableBlock(0);
// create a temporary program to load parameters.
......
......@@ -52,10 +52,13 @@ struct AnalysisConfig : public NativeConfig {
bool use_tensorrt() const { return use_tensorrt_; }
void EnableMKLDNN();
// NOTE: this is for internal development only; please do not use it.
// NOT stable yet.
bool use_mkldnn() const { return use_mkldnn_; }
// Specify the memory buffer of the program and parameters.
void SetModelBuffer(const char* prog_buffer, size_t prog_buffer_size,
const char* param_buffer, size_t param_buffer_size);
bool model_from_memory() const { return model_from_memory_; }
friend class ::paddle::AnalysisPredictor;
protected:
......@@ -64,6 +67,7 @@ struct AnalysisConfig : public NativeConfig {
int tensorrt_workspace_size_;
int tensorrt_max_batchsize_;
std::unique_ptr<PassStrategy> pass_builder_;
bool model_from_memory_{false};
};
// Configurations for Anakin engine.
......
......@@ -69,7 +69,8 @@ bool IsPersistable(const framework::VarDesc* var) {
void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
const framework::ProgramDesc& main_program,
const std::string& dirname,
const std::string& param_filename) {
const std::string& param_filename,
bool model_from_memory = false) {
const framework::BlockDesc& global_block = main_program.Block(0);
framework::ProgramDesc* load_program = new framework::ProgramDesc();
......@@ -108,6 +109,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
op->SetType("load_combine");
op->SetOutput("Out", paramlist);
op->SetAttr("file_path", {param_filename});
op->SetAttr("model_from_memory", {model_from_memory});
op->CheckAttrs();
}
......@@ -130,16 +132,17 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
"model version %ld is not supported.",
main_program->Version());
LoadPersistables(executor, scope, *main_program, dirname, "");
// model_from_memory is false when the parameters are stored in separate files.
LoadPersistables(executor, scope, *main_program, dirname, "",
false /* model_from_memory */);
return main_program;
}
std::unique_ptr<framework::ProgramDesc> Load(
framework::Executor* executor, framework::Scope* scope,
const std::string& prog_filename, const std::string& param_filename) {
std::string model_filename = prog_filename;
std::string program_desc_str;
ReadBinaryFile(model_filename, &program_desc_str);
ReadBinaryFile(prog_filename, &program_desc_str);
std::unique_ptr<framework::ProgramDesc> main_program(
new framework::ProgramDesc(program_desc_str));
......@@ -147,7 +150,22 @@ std::unique_ptr<framework::ProgramDesc> Load(
"model version %ld is not supported.",
main_program->Version());
LoadPersistables(executor, scope, *main_program, "", param_filename);
LoadPersistables(executor, scope, *main_program, "", param_filename,
false /* model_from_memory */);
return main_program;
}
std::unique_ptr<framework::ProgramDesc> LoadFromMemory(
framework::Executor* executor, framework::Scope* scope,
const std::string& prog_buffer, const std::string& param_buffer) {
std::unique_ptr<framework::ProgramDesc> main_program(
new framework::ProgramDesc(prog_buffer));
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
"model version %ld is not supported.",
main_program->Version());
LoadPersistables(executor, scope, *main_program, "", param_buffer,
true /* model_from_memory */);
return main_program;
}
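A usage sketch for the new LoadFromMemory entry point (the buffers are assumed to hold the serialized program and the combined parameters, e.g. read beforehand with ReadBinaryFile):
#include <memory>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
void LoadFromBuffers(const std::string& prog_buffer,
                     const std::string& param_buffer) {
  paddle::framework::Executor exe(paddle::platform::CPUPlace());
  paddle::framework::Scope scope;
  std::unique_ptr<paddle::framework::ProgramDesc> program =
      paddle::inference::LoadFromMemory(&exe, &scope, prog_buffer,
                                        param_buffer);
  // `program` owns the ProgramDesc; persistable parameters now live in `scope`.
}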
......
......@@ -30,7 +30,8 @@ void Init(const std::vector<std::string> argv);
void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
const framework::ProgramDesc& main_program,
const std::string& dirname,
const std::string& param_filename);
const std::string& param_filename,
bool model_from_memory);
std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
framework::Scope* scope,
......@@ -41,6 +42,10 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
const std::string& prog_filename,
const std::string& param_filename);
std::unique_ptr<framework::ProgramDesc> LoadFromMemory(
framework::Executor* executor, framework::Scope* scope,
const std::string& prog_buffer, const std::string& param_buffer);
// Save the variables from a scope to disk.
void SaveVars(const framework::Scope& scope,
const std::vector<std::string>& vars, const std::string& dirname,
......
......@@ -93,9 +93,17 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
}
}
void SetConfig(contrib::AnalysisConfig *cfg) {
cfg->prog_file = FLAGS_infer_model + "/__model__";
cfg->param_file = FLAGS_infer_model + "/param";
void SetConfig(contrib::AnalysisConfig *cfg, bool memory_load = false) {
if (memory_load) {
std::string buffer_prog, buffer_param;
ReadBinaryFile(FLAGS_infer_model + "/__model__", &buffer_prog);
ReadBinaryFile(FLAGS_infer_model + "/param", &buffer_param);
cfg->SetModelBuffer(&buffer_prog[0], buffer_prog.size(), &buffer_param[0],
buffer_param.size());
} else {
cfg->prog_file = FLAGS_infer_model + "/__model__";
cfg->param_file = FLAGS_infer_model + "/param";
}
cfg->use_gpu = false;
cfg->device = 0;
cfg->specify_input_name = true;
......@@ -114,9 +122,9 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
}
// Easy for profiling independently.
TEST(Analyzer_Chinese_ner, profile) {
void profile(bool memory_load = false) {
contrib::AnalysisConfig cfg;
SetConfig(&cfg);
SetConfig(&cfg, memory_load);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
......@@ -138,6 +146,12 @@ TEST(Analyzer_Chinese_ner, profile) {
}
}
TEST(Analyzer_Chinese_ner, profile) { profile(); }
TEST(Analyzer_Chinese_ner, profile_memory_load) {
profile(true /* memory_load */);
}
// Check the fuse status
TEST(Analyzer_Chinese_ner, fuse_statis) {
contrib::AnalysisConfig cfg;
......
......@@ -49,8 +49,6 @@ std::ostream &operator<<(std::ostream &os, const NativeConfig &config) {
os << GenSpaces(num_spaces) << "device: " << config.device << "\n";
os << GenSpaces(num_spaces)
<< "fraction_of_gpu_memory: " << config.fraction_of_gpu_memory << "\n";
os << GenSpaces(num_spaces) << "prog_file: " << config.prog_file << "\n";
os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n";
os << GenSpaces(num_spaces)
<< "specify_input_name: " << config.specify_input_name << "\n";
os << GenSpaces(num_spaces)
......@@ -65,6 +63,13 @@ std::ostream &operator<<(std::ostream &os,
os << GenSpaces(num_spaces) << "contrib::AnalysisConfig {\n";
num_spaces++;
os << *reinterpret_cast<const NativeConfig *>(&config);
if (!config.model_from_memory()) {
os << GenSpaces(num_spaces) << "prog_file: " << config.prog_file << "\n";
os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n";
} else {
os << GenSpaces(num_spaces)
<< "prog_file and param_file: load from memory \n";
}
os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.enable_ir_optim
<< "\n";
os << GenSpaces(num_spaces)
......
cc_library(benchmark SRCS benchmark.cc DEPS enforce)
cc_test(test_benchmark SRCS benchmark_tester.cc DEPS benchmark)
cc_binary(visualizer SRCS visualizer.cc DEPS analysis
paddle_pass_builder ir_pass_manager pass graph_viz_pass analysis_passes)
if(WIN32)
target_link_libraries(visualizer shlwapi)
endif(WIN32)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/utils/visualizer.h"
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <fstream>
#include <memory>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
#include "paddle/fluid/platform/init.h"
DEFINE_string(model_dir, "", "model directory");
DEFINE_string(model_program_path, "", "model program path");
DEFINE_string(model_params_path, "", "model params path");
USE_PASS(graph_viz_pass);
USE_PASS(graph_to_program_pass);
using paddle::inference::analysis::Argument;
namespace paddle {
namespace inference {
namespace utils {
void Visualizer::SetArgument(Argument *argument) { argument_ = argument; }
bool Visualizer::Run() {
paddle::framework::InitDevices(false);
paddle::inference::analysis::Analyzer().Run(argument_);
return true;
}
} // namespace utils
} // namespace inference
} // namespace paddle
// Generate a dot file describing the structure of graph.
// To use this tool, run command: ./visualizer [options...]
// Options:
// --model_dir: the directory of model
// --model_program_path: the path of program
// --model_params_path: the path of params
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
paddle::inference::analysis::Argument argument;
argument.SetUseGPU(false);
argument.SetUseTensorRT(false);
if (FLAGS_model_dir.empty()) {
if (FLAGS_model_program_path.empty() || FLAGS_model_params_path.empty()) {
LOG(ERROR) << "Please set model_dir"
" or model_program_path and model_params_path";
return -1;
} else {
argument.SetModelProgramPath(FLAGS_model_program_path);
argument.SetModelParamsPath(FLAGS_model_params_path);
}
} else {
argument.SetModelDir(FLAGS_model_dir);
}
// Only 1 pass, default filename is 0_ir_origin.dot
// For more details, looking for paddle::inference::analysis::IRPassManager
argument.SetIrAnalysisPasses({"graph_viz_pass"});
std::unique_ptr<paddle::framework::Scope> scope{
new paddle::framework::Scope()};
argument.SetScopeNotOwned(
const_cast<paddle::framework::Scope *>(scope.get()));
paddle::inference::utils::Visualizer visualizer;
visualizer.SetArgument(&argument);
visualizer.Run();
return 0;
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/inference/analysis/argument.h"
namespace paddle {
namespace inference {
namespace utils {
using paddle::inference::analysis::Argument;
class Visualizer final {
public:
Visualizer() = default;
~Visualizer() = default;
Visualizer(const Visualizer &) = delete;
Visualizer &operator=(const Visualizer &) = delete;
void SetArgument(Argument *);
bool Run();
private:
Argument *argument_;
};
} // namespace utils
} // namespace inference
} // namespace paddle
......@@ -28,6 +28,46 @@ using mkldnn::stream;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;
inline void GetWeightsTz(std::vector<int>& weights_tz, int groups, // NOLINT
bool is_conv3d) {
if (groups > 1) {
if (is_conv3d) {
int output = weights_tz[0];
int input = weights_tz[1];
int dimension = weights_tz[2];
int height = weights_tz[3];
int width = weights_tz[4];
weights_tz.resize(6);
weights_tz[0] = groups;
weights_tz[1] = output / groups;
weights_tz[2] = input;
weights_tz[3] = dimension;
weights_tz[4] = height;
weights_tz[5] = width;
} else {
int output = weights_tz[0];
int input = weights_tz[1];
int height = weights_tz[2];
int width = weights_tz[3];
weights_tz.resize(5);
weights_tz[0] = groups;
weights_tz[1] = output / groups;
weights_tz[2] = input;
weights_tz[3] = height;
weights_tz[4] = width;
}
}
}
inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format,
int groups, bool is_conv3d) {
if (is_conv3d) {
return (groups == 1) ? format : mkldnn::memory::format::goidhw;
} else {
return (groups == 1) ? format : mkldnn::memory::format::goihw;
}
}
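A worked example of the two helpers above (values are illustrative): a conv2d filter with OIHW dims {8, 4, 3, 3} and groups = 2 is reshaped to GOIHW {2, 4, 4, 3, 3}, and its weights format moves from oihw to goihw.
inline void GroupedWeightsExample() {
  std::vector<int> tz = {8, 4, 3, 3};  // O, I, H, W
  GetWeightsTz(tz, /*groups=*/2, /*is_conv3d=*/false);
  // tz is now {2, 4, 4, 3, 3}, i.e. G, O/G, I, H, W
  auto fmt = GetWeightsFormat(mkldnn::memory::format::oihw, /*groups=*/2,
                              /*is_conv3d=*/false);
  // fmt == mkldnn::memory::format::goihw
  (void)fmt;
}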
template <typename T>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
......@@ -52,10 +92,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor");
PADDLE_ENFORCE(input->dims().size() == 4,
"Input must be with 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(filter->dims().size() == 4,
"Filter must be with 4 dimensions, i.e. OIHW");
PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
"Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
"Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
if (bias) {
PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
bias->format() != memory::format::format_undef,
......@@ -71,9 +111,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
int groups = ctx.Attr<int>("groups");
bool is_conv3d = strides.size() == 3U;
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
is_conv3d
? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
dilations[2] == 1
: dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet");
const T* input_data = input->data<T>();
......@@ -83,18 +127,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> weights_tz =
paddle::framework::vectorize2int(filter->dims());
int g = std::max(groups, 1);
if (g > 1) {
int o = weights_tz[0];
int i = weights_tz[1];
int h = weights_tz[2];
int w = weights_tz[3];
weights_tz.resize(5);
weights_tz[0] = g;
weights_tz[1] = o / g;
weights_tz[2] = i;
weights_tz[3] = h;
weights_tz[4] = w;
}
GetWeightsTz(weights_tz, g, is_conv3d);
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
// Get unique name for storing MKLDNN primitives
......@@ -105,11 +138,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<primitive> pipeline;
auto src_format = input->format();
mkldnn::memory::format weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d);
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
{src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(),
(g == 1) ? filter->format() : mkldnn::memory::format::goihw);
{weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
......@@ -119,10 +155,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
}
weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
// Currently used whenever bias is != nullptr.
auto dst_md = platform::MKLDNNMemDesc(
......@@ -263,8 +305,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
mkldnn::prop_kind fwd_prop_kind) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
memory::dims stride_dims = strides;
memory::dims padding_dims = paddings;
auto conv_desc = mkldnn::convolution_forward::desc(
fwd_prop_kind, mkldnn::convolution_direct, src, weights, dst,
......@@ -288,8 +330,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
mkldnn::prop_kind fwd_prop_kind) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
memory::dims stride_dims = strides;
memory::dims padding_dims = paddings;
auto conv_desc = mkldnn::convolution_forward::desc(
fwd_prop_kind, mkldnn::convolution_direct, src, weights, bias, dst,
......@@ -349,6 +391,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
bool is_conv3d = strides.size() == 3U;
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
const T* output_grad_data = output_grad->data<T>();
......@@ -358,8 +401,14 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> weights_tz =
paddle::framework::vectorize2int(filter->dims());
int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g, is_conv3d);
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
auto src_format = input->format();
mkldnn::memory::format weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d);
// Get a unique name from the "argument" name of the "Output" variable
// as well as attributes of primitive to be created
// This name will be used as key when saving info into device context
......@@ -372,9 +421,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Create user memory descriptors
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
{src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(), filter->format());
{weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
auto user_diff_dst_md = platform::MKLDNNMemDesc(
{dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());
......@@ -386,14 +435,20 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
}
weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto diff_src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
auto diff_weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
auto diff_dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
......@@ -500,3 +555,13 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<float>);
......@@ -134,14 +134,14 @@ void Conv2DOpMaker::Make() {
"The format of output tensor is X (one-dimensional) of size equal"
"to the number of output channels. Only used with MKL-DNN.")
.AsDispensable();
AddOutput("Output",
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.");
AddInput("ResidualData",
"(Tensor) Tensor with residual data "
"to which convolution output will be added."
"Used with fuse_residual_connection fusion.")
.AsDispensable();
AddOutput("Output",
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
......@@ -232,6 +232,10 @@ $$
}
void Conv3DOpMaker::Make() {
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
......@@ -247,6 +251,11 @@ void Conv3DOpMaker::Make() {
"is the width of the filter."
"If the groups attribute is greater than 1, C equals the number of "
"input image channels divided by the groups.");
AddInput("ResidualData",
"(Tensor) Tensor with residual data "
"to which convolution output will be added."
"Used with fuse_residual_connection fusion.")
.AsDispensable();
AddOutput("Output",
"(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW.");
......@@ -280,6 +289,13 @@ void Conv3DOpMaker::Make() {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_residual_connection",
"(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is as an input to residual "
"connection.")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
......
......@@ -32,16 +32,26 @@ class LoadCombineOp : public framework::OperatorBase {
const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path");
auto load_as_fp16 = Attr<bool>("load_as_fp16");
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin),
"Cannot open file %s for load_combine op", filename);
auto model_from_memory = Attr<bool>("model_from_memory");
auto out_var_names = Outputs("Out");
PADDLE_ENFORCE_GT(
static_cast<int>(out_var_names.size()), 0,
"The number of output variables should be greater than 0.");
if (!model_from_memory) {
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin),
"Cannot open file %s for load_combine op", filename);
LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
} else {
PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory");
std::stringstream fin(filename);
LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
}
}
void LoadParamsFromBuffer(
const framework::Scope &scope, const platform::Place &place,
std::istream *buffer, bool load_as_fp16,
const std::vector<std::string> &out_var_names) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
......@@ -54,11 +64,10 @@ class LoadCombineOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
// Error checking
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot read more from file %s",
filename);
PADDLE_ENFORCE(static_cast<bool>(buffer), "Cannot read more");
// Get data from fin to tensor
DeserializeFromStream(fin, tensor, dev_ctx);
DeserializeFromStream(*buffer, tensor, dev_ctx);
auto in_dtype = framework::ToDataType(tensor->type());
auto out_dtype =
......@@ -103,11 +112,17 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"LoDTensors will be loaded from \"file_path\".")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
AddAttr<bool>("model_from_memory",
"(boolean, default false)"
"If true, file_path is in memory, and LoDTensors will be "
"loaded directly from memory")
.SetDefault(false);
AddComment(R"DOC(
LoadCombine Operator.
LoadCombine operator loads LoDTensor variables from a file. The file should
contain one or more LoDTensors serialized using the SaveCombine operator. The
LoadCombine operator loads LoDTensor variables from a file, or from a buffer
that has already been loaded into memory. The file should contain one or more
LoDTensors serialized using the SaveCombine operator. The
LoadCombine operator applies a deserialization strategy to appropriately load
the LoDTensors, and this strategy complements the serialization strategy used
in the SaveCombine operator. Hence, the LoadCombine operator is tightly coupled
......
......@@ -113,6 +113,18 @@ inline mkldnn::memory::format MKLDNNFormatForSize(
return mkldnn::memory::format::x;
} else if (dims_size == 2) {
return mkldnn::memory::format::nc;
} else if (dims_size == 3) {
if (data_format == mkldnn::memory::format::nchw) {
return mkldnn::memory::format::ncw;
} else if (data_format == mkldnn::memory::format::nhwc) {
return mkldnn::memory::format::nwc;
}
} else if (dims_size == 5) {
if (data_format == mkldnn::memory::format::nchw) {
return mkldnn::memory::format::ncdhw;
} else if (data_format == mkldnn::memory::format::nhwc) {
return mkldnn::memory::format::ndhwc;
}
}
return data_format;
}
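An illustrative mapping for the new branches (assuming the first argument is the tensor rank): 5-D conv3d tensors promote the 4-D defaults to their depth-aware counterparts, and 3-D tensors map to the 1-D-spatial ones.
inline void FormatForSizeExample() {
  auto a = MKLDNNFormatForSize(5, mkldnn::memory::format::nchw);  // ncdhw
  auto b = MKLDNNFormatForSize(5, mkldnn::memory::format::nhwc);  // ndhwc
  auto c = MKLDNNFormatForSize(3, mkldnn::memory::format::nchw);  // ncw
  (void)a; (void)b; (void)c;
}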
......
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler)
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc)
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler layer)
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc)
if(WITH_PYTHON)
if(WITH_AMD_GPU)
hip_library(paddle_pybind SHARED
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/imperative/tracer.h"
namespace paddle {
namespace pybind {
// Bind Methods
void BindTracer(pybind11::module *m) {
pybind11::class_<imperative::Tracer>(*m, "Tracer", "")
.def("__init__",
[](imperative::Tracer &self, framework::BlockDesc *root_block) {
new (&self) imperative::Tracer(root_block);
})
.def("trace", &imperative::Tracer::Trace)
.def("get_scope", &imperative::Tracer::GetScope,
pybind11::return_value_policy::reference);
}
} // namespace pybind
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <Python.h>
#include <vector>
#include "paddle/fluid/imperative/layer.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace paddle {
namespace pybind {
class PyLayer : public imperative::Layer {
public:
using imperative::Layer::Layer; // Inherit constructors
std::vector<imperative::VarBase> Forward(
const std::vector<imperative::VarBase>& inputs) override {
PYBIND11_OVERLOAD(std::vector<imperative::VarBase>, Layer, Forward,
inputs); // NOLINT
}
void Backward() override {
PYBIND11_OVERLOAD(void, Layer, Backward, ); // NOLINT
}
};
class PyOpBase : public imperative::OpBase {
public:
using imperative::OpBase::OpBase; // Inherit constructors
};
class PyVarBase : public imperative::VarBase {
public:
using imperative::VarBase::VarBase; // Inherit constructors
};
void BindTracer(pybind11::module* m);
} // namespace pybind
} // namespace paddle
......@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
......@@ -45,6 +46,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/async_executor_py.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/protobuf.h"
#include "paddle/fluid/pybind/pybind.h" // NOLINT
#include "paddle/fluid/pybind/recordio.h"
......@@ -100,6 +102,42 @@ PYBIND11_MODULE(core, m) {
BindException(&m);
py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC")
.def(py::init<>())
.def("_run_backward",
[](imperative::VarBase &self, framework::Scope *scope) {
self.RunBackward(scope);
})
.def("_grad", &imperative::VarBase::Grad)
.def_property(
"desc",
[](const imperative::VarBase &self) { return self.var_desc_; },
[](imperative::VarBase &self, framework::VarDesc *var_desc) {
self.var_desc_ = var_desc;
},
py::return_value_policy::reference);
py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
.def(py::init<>())
.def_property(
"desc", [](const imperative::OpBase &self) { return self.op_desc_; },
[](imperative::OpBase &self, framework::OpDesc *op_desc) {
if (op_desc) {
self.op_desc_ = op_desc;
}
},
py::return_value_policy::reference);
py::class_<imperative::Layer, PyLayer /* <--- trampoline*/> layer(m, "Layer");
layer.def(py::init<>())
.def("forward",
[](imperative::Layer &self,
const std::vector<imperative::VarBase> &inputs) {
return self.Forward(inputs);
})
.def("backward", &imperative::Layer::Backward);
BindTracer(&m);
py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
.def_buffer(
[](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
......@@ -601,6 +639,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("set_feed_variable", framework::SetFeedVariable);
m.def("get_fetch_variable", framework::GetFetchVariable);
m.def("get_variable_tensor", framework::GetVariableTensor);
m.def("_is_program_version_supported", IsProgramVersionSupported);
......
......@@ -34,6 +34,7 @@ from . import io
from . import evaluator
from . import initializer
from . import layers
from . import imperative
from . import contrib
from . import nets
from . import optimizer
......@@ -67,6 +68,7 @@ __all__ = framework.__all__ + executor.__all__ + \
'initializer',
'layers',
'contrib',
'imperative',
'transpiler',
'nets',
'optimizer',
......
......@@ -18,6 +18,7 @@ import collections
import contextlib
import re
import six
import sys
import numpy as np
......@@ -49,6 +50,16 @@ GRAD_VAR_SUFFIX = core.kGradVarSuffix()
ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
_imperative_tracer_ = None
def _in_imperative_mode():
return _imperative_tracer_ is not None
def _imperative_tracer():
return _imperative_tracer_
class NameScope(object):
def __init__(self, name="", parent=None):
......@@ -202,7 +213,7 @@ def _debug_string_(proto, throw_on_error=True):
return proto.__str__()
class Variable(object):
class Variable(core.VarBase):
"""
In Fluid, every input and output of an operator is a variable. In most
cases, variables are used for holding different kinds of data or training
......@@ -266,6 +277,7 @@ class Variable(object):
stop_gradient=False,
is_data=False,
**kwargs):
core.VarBase.__init__(self)
self.block = block
self.error_clip = error_clip
......@@ -346,6 +358,18 @@ class Variable(object):
self.stop_gradient = stop_gradient
self.is_data = is_data
def _numpy(self):
scope = _imperative_tracer().get_scope(self.block.desc)
tensor = core.get_variable_tensor(scope, self.desc.name())
return np.array(tensor)
def _backward(self):
scope = _imperative_tracer().get_scope(self.block.desc)
self._run_backward(scope)
def _gradient(self):
return np.array(self._grad())
def __str__(self):
return self.to_string(True)
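Under an imperative guard these helpers give eager access to forward values and gradients; a sketch mirroring the unit test added at the end of this change:
import numpy as np
import paddle.fluid as fluid

with fluid.imperative.guard():
    inp = fluid.imperative.to_variable(
        np.array([1.0, 2.0, -1.0], dtype=np.float32))
    x = fluid.layers.relu(inp)
    out = fluid.layers.elementwise_mul(x, x)
    print(out._numpy())   # eager forward value
    out._backward()       # run backward starting from this variable
    print(x._gradient())  # gradient of the intermediate relu output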
......@@ -492,7 +516,7 @@ class OpProtoHolder(object):
}
class Operator(object):
class Operator(core.OpBase):
"""
In Fluid, all operations are represented by Operators, and an Operator
is regarded as a built-in instruction of a Block. Users can use the
......@@ -548,6 +572,7 @@ class Operator(object):
inputs=None,
outputs=None,
attrs=None):
core.OpBase.__init__(self)
self.block = block
self.desc = desc
# note: not add self.attrs here:
......@@ -587,6 +612,7 @@ class Operator(object):
return True
return False
self.inputs = []
if inputs is not None:
for in_proto in proto.inputs:
found = find_name(inputs, in_proto.name)
......@@ -613,6 +639,13 @@ class Operator(object):
else:
self.desc.set_input(in_proto.name, [])
for inp in inputs.values():
if isinstance(inp, Variable):
self.inputs.append(inp)
elif isinstance(inp, list) or isinstance(inp, tuple):
self.inputs.extend(inp[:])
self.outputs = []
if outputs is not None:
given = set()
need = set()
......@@ -641,6 +674,12 @@ class Operator(object):
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
for out in outputs.values():
if isinstance(out, Variable):
self.outputs.append(out)
elif isinstance(out, list) or isinstance(out, tuple):
self.outputs.extend(out[:])
if op_attrs is not None:
if not isinstance(op_attrs, dict):
raise TypeError("'attrs' should be a dict.")
......@@ -1206,6 +1245,8 @@ class Block(object):
"""
op_desc = self.desc.append_op()
op = Operator(block=self, desc=op_desc, *args, **kwargs)
if _in_imperative_mode():
_imperative_tracer().trace(op, op.inputs, op.outputs, self.desc)
self.ops.append(op)
return op
......@@ -2209,3 +2250,12 @@ def _get_var(name, program=None):
assert isinstance(program, Program)
return program.global_block().var(name)
@contextlib.contextmanager
def _imperative_guard(tracer):
global _imperative_tracer_
tmp_trace = _imperative_tracer_
_imperative_tracer_ = tracer
yield
_imperative_tracer_ = tmp_trace
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import base
from .base import *
from . import layers
from .layers import *
__all__ = []
__all__ += layers.__all__
__all__ += base.__all__
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
__all__ = ['enabled', 'guard', 'to_variable']
def enabled():
return framework._in_imperative_mode()
@contextlib.contextmanager
def guard():
train = framework.Program()
startup = framework.Program()
tracer = core.Tracer(train.current_block().desc)
with framework.program_guard(train, startup):
with framework.unique_name.guard():
with framework._imperative_guard(tracer):
yield
def to_variable(value, block=None):
if isinstance(value, np.ndarray):
if not block:
block = framework.default_main_program().current_block()
py_var = framework.Variable(
block,
type=core.VarDesc.VarType.LOD_TENSOR,
name=None,
shape=value.shape,
dtype=value.dtype)
scope = framework._imperative_tracer().get_scope(block.desc)
var = scope.var(py_var.name)
tensor = var.get_tensor()
tensor.set(value, core.CPUPlace())
return py_var
elif isinstance(value, framework.Variable):
return value
else:
raise ValueError("Unsupported type %s" % type(value))
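A short sketch of the round trip this module enables (numpy in, numpy out, no operators involved):
import numpy as np
import paddle.fluid as fluid

with fluid.imperative.guard():
    assert fluid.imperative.enabled()
    v = fluid.imperative.to_variable(np.ones([2, 2], dtype=np.float32))
    print(v._numpy())                            # reads the tensor back from the traced scope
    assert fluid.imperative.to_variable(v) is v  # Variables pass through unchanged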
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import sys
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.imperative import base
__all__ = ['PyLayer']
class PyLayer(core.Layer):
def __init__(self):
pass
def __call__(self, inputs):
# TODO(panyx0718): Support declarative mode as well.
assert base.enabled()
if not isinstance(inputs, list) and not isinstance(inputs, tuple):
inputs = [inputs]
var_inputs = []
for x in inputs:
py_var = base.to_variable(x)
var_inputs.append(py_var)
outputs = self.forward(var_inputs)
return outputs
def forward(self, inputs):
return []
......@@ -17,10 +17,13 @@ from __future__ import print_function
import copy
import itertools
import six
import sys
import numpy as np
from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating
from . import unique_name
from paddle.fluid.initializer import Constant, Xavier
from paddle.fluid.imperative import base
from .param_attr import ParamAttr, WeightNormParamAttr
from . import core
from six.moves import zip
......@@ -46,23 +49,21 @@ class LayerHelper(object):
def startup_program(self):
return default_startup_program()
def to_variable(self, x):
return base.to_variable(x, self.main_program.current_block())
def append_op(self, *args, **kwargs):
return self.main_program.current_block().append_op(*args, **kwargs)
def multiple_input(self, input_param_name='input'):
inputs = self.kwargs.get(input_param_name, [])
type_error = TypeError(
"Input of {0} layer should be Variable or sequence of Variable".
format(self.layer_type))
if isinstance(inputs, Variable):
inputs = [inputs]
elif not isinstance(inputs, list) and not isinstance(inputs, tuple):
raise type_error
ret = []
if isinstance(inputs, list) or isinstance(inputs, tuple):
for inp in inputs:
ret.append(self.to_variable(inp))
else:
for each in inputs:
if not isinstance(each, Variable):
raise type_error
return inputs
ret.append(self.to_variable(inputs))
return ret
def input(self, input_param_name='input'):
inputs = self.multiple_input(input_param_name)
......
......@@ -6636,7 +6636,8 @@ def relu(x, name=None):
helper = LayerHelper('relu', **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type="relu", inputs={"X": x}, outputs={"Out": out})
helper.append_op(
type="relu", inputs={"X": helper.input('x')}, outputs={"Out": out})
return out
......
......@@ -102,7 +102,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
if args.mem_opt:
fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
if args.is_dist:
if args.update_method == "pserver":
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(),
args.endpoints, args.trainers,
......@@ -147,7 +147,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
def get_data():
origin_batch = next(reader_generator)
if args.is_dist and args.use_reader_alloc:
if args.update_method == "pserver" and args.use_reader_alloc:
new_batch = []
for offset, item in enumerate(origin_batch):
if offset % 2 == args.trainer_id:
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_conv3d_op import TestConv3dOp, TestCase1, TestWithGroup1, TestWithGroup2, TestWith1x1, TestWithInput1x1Filter1x1
class TestMKLDNN(TestConv3dOp):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNCase1(TestCase1):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNGroup1(TestWithGroup1):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNGroup2(TestWithGroup2):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNWith1x1(TestWith1x1):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
if __name__ == '__main__':
unittest.main()
......@@ -74,6 +74,8 @@ class TestConv3dOp(OpTest):
def setUp(self):
self.op_type = "conv3d"
self.use_cudnn = False
self.use_mkldnn = False
self.data_format = "AnyLayout"
self.dtype = np.float32
self.init_kernel_type()
self.init_group()
......@@ -83,8 +85,7 @@ class TestConv3dOp(OpTest):
conv3d_param = {
'stride': self.stride,
'pad': self.pad,
'dilations': self.dilations,
'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter
'dilations': self.dilations
}
input = np.random.random(self.input_size).astype(self.dtype)
......@@ -101,7 +102,9 @@ class TestConv3dOp(OpTest):
'paddings': self.pad,
'groups': self.groups,
'dilations': self.dilations,
'use_cudnn': self.use_cudnn
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format
}
self.outputs = {'Output': output}
......@@ -109,59 +112,35 @@ class TestConv3dOp(OpTest):
return core.is_compiled_with_cuda() and self.use_cudnn
def test_check_output(self):
if self.testcudnn():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
self.check_output_with_place(place, atol=1e-5)
def test_check_grad(self):
if self.dtype == np.float16:
return
if self.testcudnn():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['Input', 'Filter']),
'Output',
max_relative_error=0.03)
else:
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.03)
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
self.check_grad_with_place(
place, {'Input', 'Filter'}, 'Output', max_relative_error=0.03)
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
if self.testcudnn():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']))
else:
self.check_grad(
['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']))
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
if self.testcudnn():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']))
else:
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']))
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']))
def init_test_case(self):
self.pad = [0, 0, 0]
......
......@@ -76,12 +76,24 @@ class TestDistRunnerBase(object):
if args.mem_opt:
fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
if args.is_dist:
if args.update_method == "pserver":
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(),
args.endpoints, args.trainers,
args.sync_mode, args.dc_asgd)
trainer_prog = t.get_trainer_program()
elif args.update_method == "nccl2":
# transpile for nccl2
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
nccl2_t = fluid.DistributeTranspiler(config=config)
nccl2_t.transpile(
args.trainer_id,
program=fluid.default_main_program(),
startup_program=fluid.default_startup_program(),
trainers=args.endpoints,
current_endpoint=args.current_endpoint)
trainer_prog = fluid.default_main_program()
else:
trainer_prog = fluid.default_main_program()
......@@ -110,11 +122,20 @@ class TestDistRunnerBase(object):
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.batch_merge_repeat)
if args.update_method == "nccl2":
num_trainers = len(args.endpoints.split(","))
trainer_id = args.trainer_id
else:
num_trainers = 1
trainer_id = 0
exe = fluid.ParallelExecutor(
args.use_cuda,
loss_name=avg_cost.name,
exec_strategy=strategy,
build_strategy=build_stra)
build_strategy=build_stra,
num_trainers=num_trainers,
trainer_id=trainer_id)
feed_var_list = [
var for var in trainer_prog.global_block().vars.values()
......@@ -126,7 +147,7 @@ class TestDistRunnerBase(object):
def get_data():
origin_batch = next(reader_generator)
if args.is_dist and args.use_reader_alloc:
if args.update_method != "local" and args.use_reader_alloc:
new_batch = []
for offset, item in enumerate(origin_batch):
if offset % 2 == args.trainer_id:
......@@ -151,7 +172,11 @@ def runtime_main(test_class):
parser.add_argument(
'--role', type=str, required=True, choices=['pserver', 'trainer'])
parser.add_argument('--endpoints', type=str, required=False, default="")
parser.add_argument('--is_dist', action='store_true')
parser.add_argument(
'--update_method',
type=str,
default="local",
choices=["pserver", "nccl2", "local"])
parser.add_argument('--trainer_id', type=int, required=False, default=0)
parser.add_argument('--trainers', type=int, required=False, default=1)
parser.add_argument(
......@@ -170,7 +195,7 @@ def runtime_main(test_class):
args = parser.parse_args()
model = test_class()
if args.role == "pserver" and args.is_dist:
if args.role == "pserver" and args.update_method == "pserver":
model.run_pserver(args)
else:
model.run_trainer(args)
......@@ -208,6 +233,7 @@ class TestDistBase(unittest.TestCase):
self._use_reduce = False
self._dc_asgd = False # must use with async mode
self._use_reader_alloc = True
self._nccl2_mode = False
self._setup_config()
self._after_setup_config()
......@@ -218,7 +244,7 @@ class TestDistBase(unittest.TestCase):
def start_pserver(self, model_file, check_error_log, required_envs):
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --update_method pserver"
ps0_cmd = ps_cmd % \
(self._python_interp, model_file, self._ps_endpoints, ps0_ep,
self._trainers)
......@@ -270,7 +296,8 @@ class TestDistBase(unittest.TestCase):
else:
env_local = {'CPU_NUM': '1'}
envs.update(env_local)
env_local.update(envs)
print("local_cmd: {}, env: {}".format(cmd, env_local))
if check_error_log:
err_log = open("/tmp/trainer.err.log", "wb")
......@@ -278,13 +305,13 @@ class TestDistBase(unittest.TestCase):
cmd.split(" "),
stdout=subprocess.PIPE,
stderr=err_log,
env=envs)
env=env_local)
else:
local_proc = subprocess.Popen(
cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=envs)
env=env_local)
local_out, local_err = local_proc.communicate()
......@@ -303,7 +330,7 @@ class TestDistBase(unittest.TestCase):
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver"
tr0_cmd = tr_cmd % \
(self._python_interp, model, self._ps_endpoints,
0, ps0_ep, self._trainers)
......@@ -335,8 +362,8 @@ class TestDistBase(unittest.TestCase):
env0.update(envs)
env1.update(envs)
print("tr0_cmd:{}".format(tr0_cmd))
print("tr1_cmd:{}".format(tr1_cmd))
print("tr0_cmd: {}, env: {}".format(tr0_cmd, env0))
print("tr1_cmd: {}, env: {}".format(tr1_cmd, env1))
tr0_pipe = open("/tmp/tr0_err.log", "wb")
tr1_pipe = open("/tmp/tr1_err.log", "wb")
......@@ -357,12 +384,9 @@ class TestDistBase(unittest.TestCase):
# close trainer file
tr0_pipe.close()
tr1_pipe.close()
ps0_pipe.close()
ps1_pipe.close()
# FIXME: use terminate() instead of sigkill.
os.kill(ps0.pid, signal.SIGKILL)
os.kill(ps1.pid, signal.SIGKILL)
ps0.terminate()
ps1.terminate()
......@@ -372,7 +396,71 @@ class TestDistBase(unittest.TestCase):
sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out))
sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
# return tr0_losses, tr1_losses
return pickle.loads(tr0_out), pickle.loads(tr1_out)
def _run_cluster_nccl2(self, model, envs, check_error_log):
# NOTE: we reuse ps_endpoints as nccl2 worker endpoints
worker_endpoints = self._ps_endpoints.split(",")
w0_ep, w1_ep = worker_endpoints
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2"
tr0_cmd = tr_cmd % \
(self._python_interp, model, self._ps_endpoints,
0, w0_ep)
tr1_cmd = tr_cmd % \
(self._python_interp, model, self._ps_endpoints,
1, w1_ep)
if self._mem_opt:
tr0_cmd += " --mem_opt"
tr1_cmd += " --mem_opt"
if self._use_reduce:
tr0_cmd += " --use_reduce"
tr1_cmd += " --use_reduce"
if self._use_reader_alloc:
tr0_cmd += " --use_reader_alloc"
tr1_cmd += " --use_reader_alloc"
if self.__use_cuda:
tr0_cmd += " --use_cuda"
tr1_cmd += " --use_cuda"
env0 = {"CUDA_VISIBLE_DEVICES": "0"}
env1 = {"CUDA_VISIBLE_DEVICES": "1"}
else:
env0 = {'CPU_NUM': '1'}
env1 = {'CPU_NUM': '1'}
env0.update(envs)
env1.update(envs)
print("tr0_cmd:{}, env: {}".format(tr0_cmd, env0))
print("tr1_cmd:{}, env: {}".format(tr1_cmd, env1))
tr0_pipe = open("/tmp/tr0_err.log", "wb")
tr1_pipe = open("/tmp/tr1_err.log", "wb")
tr0_proc = subprocess.Popen(
tr0_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=tr0_pipe,
env=env0)
tr1_proc = subprocess.Popen(
tr1_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=tr1_pipe,
env=env1)
tr0_out, tr0_err = tr0_proc.communicate()
tr1_out, tr1_err = tr1_proc.communicate()
# close trainer file
tr0_pipe.close()
tr1_pipe.close()
# print log
sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
sys.stderr.write('trainer 0 stdout: %s\n' % tr0_out)
sys.stderr.write('trainer 1 stdout: %s\n' % tr1_out)
return pickle.loads(tr0_out), pickle.loads(tr1_out)
def check_with_place(self,
......@@ -387,20 +475,25 @@ class TestDistBase(unittest.TestCase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_fraction_of_gpu_memory_to_use": "0.15",
"FLAGS_cudnn_deterministic": "1",
"http_proxy": ""
"http_proxy": "",
"NCCL_P2P_DISABLE": "1"
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "7"
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
local_losses\
= self._run_local(model_file, required_envs,
check_error_log)
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs,
check_error_log)
if self._nccl2_mode:
tr0_losses, tr1_losses = self._run_cluster_nccl2(
model_file, required_envs, check_error_log)
else:
tr0_losses, tr1_losses = self._run_cluster(
model_file, required_envs, check_error_log)
for step_id in range(RUN_STEP):
local_loss = local_losses[step_id]
......
......@@ -26,6 +26,19 @@ class TestDistMnist2x2(TestDistBase):
self.check_with_place("dist_mnist.py", delta=1e-5)
class TestDistMnistNCCL2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
def test_dist_train(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place("dist_mnist.py", delta=1)
class TestDistMnist2x2Lars(TestDistBase):
def _setup_config(self):
self._sync_mode = True
......
......@@ -44,7 +44,7 @@ class TestDistSaveLoadDense2x2(TestDistBase):
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "7"
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
model_dir = tempfile.mkdtemp()
......
......@@ -769,6 +769,7 @@ class TestNCCL2Transpile(TranspilerTest):
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
config.wait_port = False
t = fluid.DistributeTranspiler(config=config)
t.transpile(
0,
......
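For reference, a complete nccl2 transpile call in the style the new test runner uses (the endpoint strings are placeholders):
import paddle.fluid as fluid

config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
config.wait_port = False  # new switch: skip waiting for peer ports to be ready

t = fluid.DistributeTranspiler(config=config)
t.transpile(
    trainer_id=0,
    program=fluid.default_main_program(),
    startup_program=fluid.default_startup_program(),
    trainers="127.0.0.1:6174,127.0.0.1:6175",  # placeholder worker endpoints
    current_endpoint="127.0.0.1:6174")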
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core
class MyLayer(fluid.imperative.PyLayer):
def __init__(self):
super(MyLayer, self).__init__()
def forward(self, inputs):
x = fluid.layers.relu(inputs[0])
self._x_for_debug = x
return [fluid.layers.elementwise_mul(x, x)]
class TestImperative(unittest.TestCase):
def test_layer(self):
with fluid.imperative.guard():
cl = core.Layer()
cl.forward([])
l = fluid.imperative.PyLayer()
l.forward([])
def test_layer_in_out(self):
with fluid.imperative.guard():
l = MyLayer()
x = l(np.array([1.0, 2.0, -1.0], dtype=np.float32))[0]
self.assertIsNotNone(x)
sys.stderr.write("%s output: %s\n" % (x, x._numpy()))
x._backward()
sys.stderr.write("grad %s\n" % l._x_for_debug._gradient())
if __name__ == '__main__':
unittest.main()
......@@ -142,6 +142,7 @@ class DistributeTranspilerConfig(object):
# supported modes: pserver, nccl2
mode = "pserver"
print_log = False
wait_port = True
class DistributeTranspiler(object):
......@@ -171,7 +172,6 @@ class DistributeTranspiler(object):
trainer_id = 0
trainers = 4
role = os.getenv("PADDLE_TRAINING_ROLE")
t = fluid.DistributeTranspiler()
t.transpile(
trainer_id, pservers=pserver_endpoints, trainers=trainers)
......@@ -214,13 +214,16 @@ class DistributeTranspiler(object):
trainer_id,
trainers,
current_endpoint,
startup_program=None):
startup_program=None,
wait_port=True):
if not startup_program:
startup_program = default_startup_program()
if trainer_id >= 0:
worker_endpoints = trainers.split(",")
# send NCCL_ID to others or recv from trainer 0
worker_endpoints.remove(current_endpoint)
if trainer_id == 0 and wait_port:
wait_server_ready(worker_endpoints)
nccl_id_var = startup_program.global_block().create_var(
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
......@@ -306,7 +309,8 @@ class DistributeTranspiler(object):
trainer_id,
trainers,
current_endpoint,
startup_program=startup_program)
startup_program=startup_program,
wait_port=self.config.wait_port)
return
self.trainer_num = trainers
......@@ -652,9 +656,6 @@ class DistributeTranspiler(object):
# NOTE: assume blocks of the same variable are not distributed
# on the same pserver, only change param/grad varnames for
# trainers to fetch.
sys.stderr.write("get_pserver_program() is deprecated, call \
get_pserver_programs() to get pserver main and startup \
in a single call.")
# step1
pserver_program = Program()
pserver_program.random_seed = self.origin_program.random_seed
......@@ -922,18 +923,6 @@ in a single call.")
Returns:
Program: parameter server side startup program.
"""
sys.stderr.write("get_startup_program() is deprecated, call \
get_pserver_programs() to get pserver main and startup \
in a single call.")
if pserver_program != None:
sys.stderr.write("passing pserver_program to get_startup_program() \
is deprecated, you can use new API get_pserver_programs() to \
get both pserver main program and startup program.")
if startup_program != None:
sys.stderr.write("passing startup_program to get_startup_program() \
is deprecated, use fluid.program_guard() or pass this argument \
to transpile() call.")
s_prog = Program()
orig_s_prog = self.startup_program
s_prog.random_seed = orig_s_prog.random_seed
......
......@@ -101,6 +101,7 @@ packages=['paddle',
'paddle.dataset',
'paddle.reader',
'paddle.fluid',
'paddle.fluid.imperative',
'paddle.fluid.proto',
'paddle.fluid.proto.profiler',
'paddle.fluid.layers',
......
......@@ -27,6 +27,8 @@ import pydoc
member_dict = collections.OrderedDict()
experimental_namespace = {"paddle.fluid.imperative"}
def visit_member(parent_name, member):
cur_name = ".".join([parent_name, member.__name__])
......@@ -51,6 +53,8 @@ def visit_member(parent_name, member):
def visit_all_module(mod):
if (mod.__name__ in experimental_namespace):
return
for member_name in (
name
for name in (mod.__all__ if hasattr(mod, "__all__") else dir(mod))
......