提交 57033869 编写于 作者: M minqiyang

Add debug info

上级 202b2f1f
...@@ -26,17 +26,46 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, ...@@ -26,17 +26,46 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
scope_(scope), scope_(scope),
place_(place) {} place_(place) {}
struct RecordTime {
RecordTime(const std::string &name, const std::string &type)
: name_(name), type_(type), start_(std::chrono::system_clock::now()) {}
~RecordTime() {
if (type_ == "elementsize_add") {
end_ = std::chrono::system_clock::now();
std::chrono::duration<double> diff = end_ - start_;
VLOG(1) << name_ << " " << type_ << " time record: " << diff.count();
}
}
std::string name_;
std::string type_;
std::chrono::system_clock::time_point start_;
std::chrono::system_clock::time_point end_;
};
void ComputationOpHandle::RunImpl() { void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_); {
RecordTime rt("ComputationOpHandle::RunImpl", "Wait");
WaitInputVarGenerated(place_);
}
Scope *scope = nullptr;
{
RecordTime rt("ComputationOpHandle::RunImpl", "PrepareScope");
scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
}
{
RecordTime rt("ComputationOpHandle::RunImpl", "ReallyRun " + op_->Type());
auto run_func = [this]() { auto run_func = [this, scope]() { op_->Run(*scope, place_); };
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
};
if (is_lock_and_record_event_free_) { if (is_lock_and_record_event_free_) {
run_func(); run_func();
} else { } else {
this->RunAndRecordEvent(run_func); this->RunAndRecordEvent(run_func);
}
} }
} }
......
...@@ -120,6 +120,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -120,6 +120,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
ClearFetchOp(graph_.get(), &fetch_ops); ClearFetchOp(graph_.get(), &fetch_ops);
return fetches; return fetches;
} }
void FastThreadedSSAGraphExecutor::RunOpAsync( void FastThreadedSSAGraphExecutor::RunOpAsync(
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps, std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op, OpHandleBase *op,
......
...@@ -41,7 +41,7 @@ OpHandleBase::~OpHandleBase() { ...@@ -41,7 +41,7 @@ OpHandleBase::~OpHandleBase() {
void OpHandleBase::Run(bool use_cuda) { void OpHandleBase::Run(bool use_cuda) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (events_.empty() && use_cuda) { if (events_.empty() && use_cuda && !dev_ctxes_.empty()) {
for (auto &p : dev_ctxes_) { for (auto &p : dev_ctxes_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device; int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
PADDLE_ENFORCE(cudaSetDevice(dev_id)); PADDLE_ENFORCE(cudaSetDevice(dev_id));
......
...@@ -701,85 +701,125 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, ...@@ -701,85 +701,125 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
} }
struct RecordTime {
RecordTime(const std::string& name, const std::string& type)
: name_(name), type_(type), start_(std::chrono::system_clock::now()) {}
void inline stop() {
end_ = std::chrono::system_clock::now();
std::chrono::duration<double> diff = end_ - start_;
VLOG(1) << name_ << " " << type_ << " time record: " << diff.count();
}
~RecordTime() {
if (type_ == "elementwise_add") {
stop();
}
// stop();
}
std::string name_;
std::string type_;
std::chrono::system_clock::time_point start_;
std::chrono::system_clock::time_point end_;
};
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RecordTime rt("OperatorWithKernel::All", type_);
this->InferShape(&infer_shape_ctx); {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); RecordTime rt("OperatorWithKernel::InferShape", type_);
auto* dev_ctx = pool.Get(place); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW(
"There are no kernels which are registered in the %s operator.", type_);
} }
OpKernelMap& kernels = kernels_iter->second; {
RecordTime* rt_1 = new RecordTime("OperatorWithKernel::Compute1", type_);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the // check if op[type] has kernel registered.
// transform functions are ready. auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW(
"There are no kernels which are registered in the %s operator.",
type_);
}
// for (auto& candidate : kKernelPriority) { OpKernelMap& kernels = kernels_iter->second;
// Do selection
// }
auto expected_kernel_key = // TODO(dzhwinter) : kernel fallback mechanism will be added when all the
this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx)); // transform functions are ready.
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key); // for (auto& candidate : kKernelPriority) {
// Do selection
// }
auto expected_kernel_key =
this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
if (kernel_iter == kernels.end() && if (kernel_iter == kernels.end() &&
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) { expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one"; VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
expected_kernel_key.library_type_ = LibraryType::kPlain; expected_kernel_key.library_type_ = LibraryType::kPlain;
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout; expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
kernel_iter = kernels.find(expected_kernel_key); kernel_iter = kernels.find(expected_kernel_key);
} }
#endif #endif
if (kernel_iter == kernels.end()) { if (kernel_iter == kernels.end()) {
PADDLE_THROW("op %s does not have kernel for %s", type_, PADDLE_THROW("op %s does not have kernel for %s", type_,
KernelTypeToString(expected_kernel_key)); KernelTypeToString(expected_kernel_key));
} }
// do data transformScope &transfer_scope; // do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars; std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope = Scope* transfer_scope = nullptr;
TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars); // auto* transfer_scope =
// TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);
// exec scope is the scope that kernel actually executed on. // exec scope is the scope that kernel actually executed on.
const Scope& exec_scope = const Scope& exec_scope = scope;
(transfer_scope == nullptr ? scope : *transfer_scope); // const Scope& exec_scope =
// (transfer_scope == nullptr ? scope : *transfer_scope);
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) { if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
delete rt_1;
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx)); RecordTime* rt_2 = new RecordTime("OperatorWithKernel::Compute2", type_);
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
delete rt_2;
if (!transfered_inplace_vars.empty()) { RecordTime* rt_3 = new RecordTime("OperatorWithKernel::Compute3", type_);
// there is inplace variable has been transfered. if (!transfered_inplace_vars.empty()) {
TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope); // there is inplace variable has been transfered.
} TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
}
/*For profiling/benchmark only*/ /*For profiling/benchmark only*/
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
dev_ctx->Wait(); dev_ctx->Wait();
} }
if (FLAGS_check_nan_inf) { if (FLAGS_check_nan_inf) {
for (auto& vname : OutputVars(true)) { for (auto& vname : OutputVars(true)) {
auto* var = exec_scope.FindVar(vname); auto* var = exec_scope.FindVar(vname);
if (var == nullptr) continue; if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>()); CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
CheckTensorNANOrInf(vname, var->Get<framework::SelectedRows>().value()); CheckTensorNANOrInf(vname,
var->Get<framework::SelectedRows>().value());
}
} }
} }
delete rt_3;
} }
} }
void OperatorWithKernel::TransferInplaceVarsBack( void OperatorWithKernel::TransferInplaceVarsBack(
......
...@@ -43,9 +43,16 @@ DEFINE_double( ...@@ -43,9 +43,16 @@ DEFINE_double(
// the mutex will cause serious performance issue. // the mutex will cause serious performance issue.
// So the mutex is disabled when `ON_INFER`. // So the mutex is disabled when `ON_INFER`.
#ifdef PADDLE_ON_INFERENCE #ifdef PADDLE_ON_INFERENCE
#define SCOPE_LOCK_GUARD #define SCOPE_READER_LOCK
#define SCOPE_WRITER_LOCK
#else #else
#define SCOPE_LOCK_GUARD std::lock_guard<std::mutex> lock(mutex_); // TODO(minqiyang): use reader lock and writer lock in all platforms
#define SCOPE_READER_LOCK
#define SCOPE_WRITER_LOCK
// #define SCOPE_READER_LOCK boost::shared_lock<boost::shared_mutex>
// lock(mutex_);
// #define SCOPE_WRITER_LOCK boost::unique_lock<boost::shared_mutex>
// lock(mutex_);
#endif #endif
namespace paddle { namespace paddle {
...@@ -61,18 +68,18 @@ int64_t GetEagerDeletionThreshold() { ...@@ -61,18 +68,18 @@ int64_t GetEagerDeletionThreshold() {
Scope::~Scope() { DropKids(); } Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const { Scope& Scope::NewScope() const {
SCOPE_LOCK_GUARD SCOPE_WRITER_LOCK
kids_.push_back(new Scope(this)); kids_.push_back(new Scope(this));
return *kids_.back(); return *kids_.back();
} }
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
SCOPE_LOCK_GUARD SCOPE_WRITER_LOCK
return VarInternal(name); return VarInternal(name);
} }
Variable* Scope::Var(std::string* name) { Variable* Scope::Var(std::string* name) {
SCOPE_LOCK_GUARD SCOPE_WRITER_LOCK
auto new_name = string::Sprintf("%p.%d", this, vars_.size()); auto new_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) { if (name != nullptr) {
*name = new_name; *name = new_name;
...@@ -81,34 +88,34 @@ Variable* Scope::Var(std::string* name) { ...@@ -81,34 +88,34 @@ Variable* Scope::Var(std::string* name) {
} }
Variable* Scope::FindVar(const std::string& name) const { Variable* Scope::FindVar(const std::string& name) const {
SCOPE_LOCK_GUARD SCOPE_READER_LOCK
return FindVarInternal(name); return FindVarInternal(name);
} }
Variable* Scope::FindLocalVar(const std::string& name) const { Variable* Scope::FindLocalVar(const std::string& name) const {
SCOPE_LOCK_GUARD SCOPE_READER_LOCK
return FindVarLocally(name); return FindVarLocally(name);
} }
const Scope* Scope::FindScope(const Variable* var) const { const Scope* Scope::FindScope(const Variable* var) const {
SCOPE_LOCK_GUARD SCOPE_READER_LOCK
return FindScopeInternal(var); return FindScopeInternal(var);
} }
void Scope::DropKids() { void Scope::DropKids() {
SCOPE_LOCK_GUARD SCOPE_WRITER_LOCK
for (Scope* s : kids_) delete s; for (Scope* s : kids_) delete s;
kids_.clear(); kids_.clear();
} }
bool Scope::HasKid(const Scope* scope) const { bool Scope::HasKid(const Scope* scope) const {
SCOPE_LOCK_GUARD SCOPE_READER_LOCK
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
return it != this->kids_.end(); return it != this->kids_.end();
} }
std::vector<std::string> Scope::LocalVarNames() const { std::vector<std::string> Scope::LocalVarNames() const {
SCOPE_LOCK_GUARD SCOPE_READER_LOCK
std::vector<std::string> known_vars; std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size()); known_vars.reserve(this->vars_.size());
for (auto& p : vars_) { for (auto& p : vars_) {
...@@ -118,7 +125,7 @@ std::vector<std::string> Scope::LocalVarNames() const { ...@@ -118,7 +125,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
} }
void Scope::DeleteScope(Scope* scope) const { void Scope::DeleteScope(Scope* scope) const {
SCOPE_LOCK_GUARD SCOPE_WRITER_LOCK
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "%p Cannot find %p as kid scope", PADDLE_ENFORCE(it != this->kids_.end(), "%p Cannot find %p as kid scope",
this, scope); this, scope);
...@@ -132,7 +139,7 @@ void Scope::DeleteScope(Scope* scope) const { ...@@ -132,7 +139,7 @@ void Scope::DeleteScope(Scope* scope) const {
} }
void Scope::EraseVars(const std::vector<std::string>& var_names) { void Scope::EraseVars(const std::vector<std::string>& var_names) {
SCOPE_LOCK_GUARD SCOPE_WRITER_LOCK
std::set<std::string> var_set(var_names.begin(), var_names.end()); std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) { for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) { if (var_set.find(it->first) != var_set.end()) {
...@@ -145,12 +152,12 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) { ...@@ -145,12 +152,12 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
void Scope::Rename(const std::string& origin_name, void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const { const std::string& new_name) const {
SCOPE_LOCK_GUARD SCOPE_WRITER_LOCK
RenameInternal(origin_name, new_name); RenameInternal(origin_name, new_name);
} }
std::string Scope::Rename(const std::string& origin_name) const { std::string Scope::Rename(const std::string& origin_name) const {
SCOPE_LOCK_GUARD SCOPE_WRITER_LOCK
auto new_name = string::Sprintf("%p.%d", this, vars_.size()); auto new_name = string::Sprintf("%p.%d", this, vars_.size());
RenameInternal(origin_name, new_name); RenameInternal(origin_name, new_name);
return new_name; return new_name;
......
...@@ -33,34 +33,37 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -33,34 +33,37 @@ class ElementwiseOp : public framework::OperatorWithKernel {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), if (!ctx->IsRuntime()) {
"Input(X) of elementwise op should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"),
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(X) of elementwise op should not be null.");
"Input(Y) of elementwise op should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"),
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Input(Y) of elementwise op should not be null.");
"Output(Out) of elementwise op should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of elementwise op should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Y").front() == PADDLE_ENFORCE(ctx->GetInputsVarType("Y").front() ==
framework::proto::VarType::LOD_TENSOR, framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s [%s]", "The input var's type should be LoDTensor, but the "
ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front()); "received is %s [%s]",
ctx->GetInputsVarType("Y").front(),
if (ctx->GetInputsVarType("X").front() == ctx->Inputs("Y").front());
framework::proto::VarType::LOD_TENSOR) {
auto x_dim = ctx->GetInputDim("X"); if (ctx->GetInputsVarType("X").front() ==
auto y_dim = ctx->GetInputDim("Y"); framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), auto x_dim = ctx->GetInputDim("X");
"Rank of first input must >= rank of second input."); auto y_dim = ctx->GetInputDim("Y");
} else if (ctx->GetInputsVarType("X").front() == PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
framework::proto::VarType::SELECTED_ROWS) { "Rank of first input must >= rank of second input.");
PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) && } else if (ctx->GetInputsVarType("X").front() ==
(ctx->GetInputDim("Y")[0] == 1), framework::proto::VarType::SELECTED_ROWS) {
"For elementwise_op, if X is Sparse, " PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) &&
"Y must be scalar."); (ctx->GetInputDim("Y")[0] == 1),
} else { "For elementwise_op, if X is Sparse, "
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.", "Y must be scalar.");
ctx->GetInputsVarType("X").front()); } else {
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
ctx->GetInputsVarType("X").front());
}
} }
ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareDim("X", /*->*/ "Out");
...@@ -125,7 +128,7 @@ The equation is: ...@@ -125,7 +128,7 @@ The equation is:
$$%s$$ $$%s$$
- $X$: a tensor of any dimension. - $X$: a tensor of any dimension.
- $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$. - $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$.
There are two cases for this operator: There are two cases for this operator:
...@@ -135,10 +138,10 @@ There are two cases for this operator: ...@@ -135,10 +138,10 @@ There are two cases for this operator:
For case 2: For case 2:
1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index 1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
for broadcasting $Y$ onto $X$. for broadcasting $Y$ onto $X$.
2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$. 2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of 3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
subsequence, such as shape(Y) = (2, 1) => (2). subsequence, such as shape(Y) = (2, 1) => (2).
For example: For example:
...@@ -152,7 +155,7 @@ For example: ...@@ -152,7 +155,7 @@ For example:
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0 shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0 shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
The inputs $X$ and $Y$ can carry the different LoD information. The inputs $X$ and $Y$ can carry the different LoD information.
But the output only shares the LoD information with the input $X$. But the output only shares the LoD information with the input $X$.
)DOC", )DOC",
......
...@@ -23,56 +23,57 @@ class AdamOp : public framework::OperatorWithKernel { ...@@ -23,56 +23,57 @@ class AdamOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"), // PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdamOp should not be null."); // "Input(Param) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"), // PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of AdamOp should not be null."); // "Input(Grad) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment1"), // PADDLE_ENFORCE(ctx->HasInput("Moment1"),
"Input(Moment1) of AdamOp should not be null."); // "Input(Moment1) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment2"), // PADDLE_ENFORCE(ctx->HasInput("Moment2"),
"Input(Moment2) of AdamOp should not be null."); // "Input(Moment2) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"), // PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of AdamOp should not be null."); // "Input(LearningRate) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"), // PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
"Input(Beta1Pow) of AdamOp should not be null."); // "Input(Beta1Pow) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"), // PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
"Input(Beta2Pow) of AdamOp should not be null."); // "Input(Beta2Pow) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), // PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdamOp should not be null."); // "Output(ParamOut) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"), // PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
"Output(Moment1Out) of AdamOp should not be null."); // "Output(Moment1Out) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"), // PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
"Output(Moment2Out) of AdamOp should not be null."); // "Output(Moment2Out) of AdamOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate"); auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, // PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 dimension"); // "Learning rate should have 1 dimension");
auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow"); auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1, // PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
"Beta1 power accumulator should have 1 dimension"); // "Beta1 power accumulator should have 1 dimension");
auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow"); auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1, // PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
"Beta2 power accumulator should have 1 dimension"); // "Beta2 power accumulator should have 1 dimension");
auto param_dims = ctx->GetInputDim("Param"); auto param_dims = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] == // if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) { // framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ( // PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"), // param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamOp should have same dimension"); // "Param and Grad input of AdamOp should have same dimension");
} // }
PADDLE_ENFORCE_EQ( // PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment1"), // param_dims, ctx->GetInputDim("Moment1"),
"Param and Moment1 input of AdamOp should have same dimension"); // "Param and Moment1 input of AdamOp should have same dimension");
PADDLE_ENFORCE_EQ( // PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment2"), // param_dims, ctx->GetInputDim("Moment2"),
"Param and Moment2 input of AdamOp should have same dimension"); // "Param and Moment2 input of AdamOp should have same dimension");
ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("Moment1Out", param_dims); ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims); ctx->SetOutputDim("Moment2Out", param_dims);
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
......
...@@ -92,7 +92,8 @@ def cuda_profiler(output_file, output_mode=None, config=None): ...@@ -92,7 +92,8 @@ def cuda_profiler(output_file, output_mode=None, config=None):
config_file = 'nvprof_config_file' config_file = 'nvprof_config_file'
with open(config_file, 'wb') as fp: with open(config_file, 'wb') as fp:
fp.writelines([six.b("%s\n" % item) for item in config]) fp.writelines([six.b("%s\n" % item) for item in config])
core.nvprof_init(output_file, output_mode, config_file) #Comment this for nvprof
#core.nvprof_init(output_file, output_mode, config_file)
# Enables profiler collection by the active CUDA profiling tool. # Enables profiler collection by the active CUDA profiling tool.
core.nvprof_start() core.nvprof_start()
yield yield
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册