提交 b75bd29c 编写于 作者: M minqiyang

Remove debug info

上级 7a43e517
...@@ -26,46 +26,17 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, ...@@ -26,46 +26,17 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
scope_(scope), scope_(scope),
place_(place) {} place_(place) {}
struct RecordTime {
RecordTime(const std::string &name, const std::string &type)
: name_(name), type_(type), start_(std::chrono::system_clock::now()) {}
~RecordTime() {
if (type_ == "elementsize_add") {
end_ = std::chrono::system_clock::now();
std::chrono::duration<double> diff = end_ - start_;
VLOG(1) << name_ << " " << type_ << " time record: " << diff.count();
}
}
std::string name_;
std::string type_;
std::chrono::system_clock::time_point start_;
std::chrono::system_clock::time_point end_;
};
void ComputationOpHandle::RunImpl() { void ComputationOpHandle::RunImpl() {
{ WaitInputVarGenerated(place_);
RecordTime rt("ComputationOpHandle::RunImpl", "Wait");
WaitInputVarGenerated(place_);
}
Scope *scope = nullptr;
{
RecordTime rt("ComputationOpHandle::RunImpl", "PrepareScope");
scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
}
{
RecordTime rt("ComputationOpHandle::RunImpl", "ReallyRun " + op_->Type());
auto run_func = [this, scope]() { op_->Run(*scope, place_); }; auto run_func = [this]() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
};
if (is_lock_and_record_event_free_) { if (is_lock_and_record_event_free_) {
run_func(); run_func();
} else { } else {
this->RunAndRecordEvent(run_func); this->RunAndRecordEvent(run_func);
}
} }
} }
......
...@@ -41,7 +41,7 @@ OpHandleBase::~OpHandleBase() { ...@@ -41,7 +41,7 @@ OpHandleBase::~OpHandleBase() {
void OpHandleBase::Run(bool use_cuda) { void OpHandleBase::Run(bool use_cuda) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (events_.empty() && use_cuda && !dev_ctxes_.empty()) { if (events_.empty() && use_cuda) {
for (auto &p : dev_ctxes_) { for (auto &p : dev_ctxes_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device; int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
PADDLE_ENFORCE(cudaSetDevice(dev_id)); PADDLE_ENFORCE(cudaSetDevice(dev_id));
......
...@@ -20,6 +20,10 @@ limitations under the License. */ ...@@ -20,6 +20,10 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
DEFINE_bool(enforce_when_check_program, true,
"Checking whether the program is correct or not. We will log "
"errors rather than throwing exceptions if this flag turned off");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
...@@ -28,55 +32,85 @@ namespace { ...@@ -28,55 +32,85 @@ namespace {
void CheckProgram(const ProgramDesc &program) { void CheckProgram(const ProgramDesc &program) {
#define _INT(role) static_cast<int>(role) #define _INT(role) static_cast<int>(role)
// std::map<int, bool> visit; std::map<int, bool> visit;
// for (OpDesc *op : program.Block(0).AllOps()) { for (OpDesc *op : program.Block(0).AllOps()) {
// // For backward compatibility, some program doesn't have role added. // For backward compatibility, some program doesn't have role added.
// if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue; if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
// int role_id = int role_id =
// boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())); boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
// visit[role_id] = true; visit[role_id] = true;
// switch (role_id) { switch (role_id) {
// case _INT(OpRole::kForward): case _INT(OpRole::kForward):
// if (visit.find(_INT(OpRole::kBackward)) != visit.end()) { if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
// LOG(ERROR) LOG(ERROR)
// << "Cannot add backward operator before forward operator %s." << "Cannot add backward operator before forward operator %s."
// << op->Type(); << op->Type();
// } }
// break; break;
// case _INT(OpRole::kBackward): case _INT(OpRole::kBackward):
// case _INT(OpRole::kBackward) | _INT(OpRole::kLoss): case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
// PADDLE_ENFORCE( if (!FLAGS_enforce_when_check_program) {
// visit.find(_INT(OpRole::kOptimize)) == visit.end(), PADDLE_ENFORCE(
// "Cannot add backward operator %s after optimize operator.", visit.find(_INT(OpRole::kOptimize)) == visit.end(),
// op->Type()); "Cannot add backward operator %s after optimize operator.",
// break; op->Type());
// case _INT(OpRole::kForward) | _INT(OpRole::kLoss): } else {
// PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) | if (visit.find(_INT(OpRole::kOptimize)) != visit.end()) {
// _INT(OpRole::kLoss)) == visit.end(), LOG(ERROR)
// "Cannot add backward|loss operator before " << "Cannot add backward operator %s after optimize operator.",
// "forward|loss operator %s.", << op->Type();
// op->Type()); }
// PADDLE_ENFORCE( }
// visit.find(_INT(OpRole::kOptimize)) == visit.end(), break;
// "Cannot add forward|loss operator %s after optimize operator.", case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
// op->Type()); if (!FLAGS_enforce_when_check_program) {
// break; PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
// case _INT(OpRole::kOptimize): _INT(OpRole::kLoss)) == visit.end(),
// case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched): "Cannot add backward|loss operator before "
// PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(), "forward|loss operator %s.",
// "Optimize operators %s must follow backward operator.", op->Type());
// op->Type()); PADDLE_ENFORCE(
// break; visit.find(_INT(OpRole::kOptimize)) == visit.end(),
// case _INT(OpRole::kLRSched): "Cannot add forward|loss operator %s after optimize operator.",
// case _INT(OpRole::kDist): op->Type());
// case _INT(OpRole::kRPC): } else {
// case _INT(OpRole::kNotSpecified): if (visit.find(_INT(OpRole::kBackward) | _INT(OpRole::kLoss)) !=
// break; visit.end()) {
// default: LOG(ERROR) << "Cannot add backward|loss operator before "
// LOG(FATAL) << "Unknown operator role. Don't add new role because " << "forward|loss operator %s." << op->Type();
// "you don't know what you are doing."; }
// }
// } if (visit.find(_INT(OpRole::kOptimize)) != visit.end()) {
LOG(ERROR) << "Cannot add forward|loss operator %s after optimize "
"operator.",
<< op->Type();
}
}
break;
case _INT(OpRole::kOptimize):
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
if (!FLAGS_enforce_when_check_program) {
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
"Optimize operators %s must follow backward operator.",
op->Type());
} else {
if (visit.find(_INT(OpRole::kBackward)) == visit.end()) {
LOG(ERROR)
<< "Optimize operators %s must follow backward operator.",
<< op->Type();
}
}
break;
case _INT(OpRole::kLRSched):
case _INT(OpRole::kDist):
case _INT(OpRole::kRPC):
case _INT(OpRole::kNotSpecified):
break;
default:
LOG(FATAL) << "Unknown operator role. Don't add new role because "
"you don't know what you are doing.";
}
}
#undef _INT #undef _INT
} }
......
...@@ -701,125 +701,85 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, ...@@ -701,125 +701,85 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
} }
struct RecordTime {
RecordTime(const std::string& name, const std::string& type)
: name_(name), type_(type), start_(std::chrono::system_clock::now()) {}
void inline stop() {
end_ = std::chrono::system_clock::now();
std::chrono::duration<double> diff = end_ - start_;
VLOG(1) << name_ << " " << type_ << " time record: " << diff.count();
}
~RecordTime() {
if (type_ == "elementwise_add") {
stop();
}
// stop();
}
std::string name_;
std::string type_;
std::chrono::system_clock::time_point start_;
std::chrono::system_clock::time_point end_;
};
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RecordTime rt("OperatorWithKernel::All", type_); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
{ this->InferShape(&infer_shape_ctx);
RecordTime rt("OperatorWithKernel::InferShape", type_); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
RuntimeInferShapeContext infer_shape_ctx(*this, scope); auto* dev_ctx = pool.Get(place);
this->InferShape(&infer_shape_ctx);
}
{
RecordTime* rt_1 = new RecordTime("OperatorWithKernel::Compute1", type_);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// check if op[type] has kernel registered. // check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels(); auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_); auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) { if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW( PADDLE_THROW(
"There are no kernels which are registered in the %s operator.", "There are no kernels which are registered in the %s operator.", type_);
type_); }
}
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the // TODO(dzhwinter) : kernel fallback mechanism will be added when all the
// transform functions are ready. // transform functions are ready.
// for (auto& candidate : kKernelPriority) { // for (auto& candidate : kKernelPriority) {
// Do selection // Do selection
// } // }
auto expected_kernel_key = auto expected_kernel_key =
this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx)); this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key; VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
if (kernel_iter == kernels.end() && if (kernel_iter == kernels.end() &&
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) { expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one"; VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
expected_kernel_key.library_type_ = LibraryType::kPlain; expected_kernel_key.library_type_ = LibraryType::kPlain;
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout; expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
kernel_iter = kernels.find(expected_kernel_key); kernel_iter = kernels.find(expected_kernel_key);
} }
#endif #endif
if (kernel_iter == kernels.end()) { if (kernel_iter == kernels.end()) {
PADDLE_THROW("op %s does not have kernel for %s", type_, PADDLE_THROW("op %s does not have kernel for %s", type_,
KernelTypeToString(expected_kernel_key)); KernelTypeToString(expected_kernel_key));
} }
// do data transformScope &transfer_scope; // do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars; std::vector<std::string> transfered_inplace_vars;
Scope* transfer_scope = nullptr; auto* transfer_scope =
// auto* transfer_scope = TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);
// TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);
// exec scope is the scope that kernel actually executed on. // exec scope is the scope that kernel actually executed on.
const Scope& exec_scope = scope; const Scope& exec_scope =
// const Scope& exec_scope = (transfer_scope == nullptr ? scope : *transfer_scope);
// (transfer_scope == nullptr ? scope : *transfer_scope);
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) { if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
delete rt_1;
RecordTime* rt_2 = new RecordTime("OperatorWithKernel::Compute2", type_); kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
delete rt_2;
RecordTime* rt_3 = new RecordTime("OperatorWithKernel::Compute3", type_); if (!transfered_inplace_vars.empty()) {
if (!transfered_inplace_vars.empty()) { // there is inplace variable has been transfered.
// there is inplace variable has been transfered. TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope); }
}
/*For profiling/benchmark only*/ /*For profiling/benchmark only*/
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
dev_ctx->Wait(); dev_ctx->Wait();
} }
if (FLAGS_check_nan_inf) { if (FLAGS_check_nan_inf) {
for (auto& vname : OutputVars(true)) { for (auto& vname : OutputVars(true)) {
auto* var = exec_scope.FindVar(vname); auto* var = exec_scope.FindVar(vname);
if (var == nullptr) continue; if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>()); CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
CheckTensorNANOrInf(vname, CheckTensorNANOrInf(vname, var->Get<framework::SelectedRows>().value());
var->Get<framework::SelectedRows>().value());
}
} }
} }
delete rt_3;
} }
} }
void OperatorWithKernel::TransferInplaceVarsBack( void OperatorWithKernel::TransferInplaceVarsBack(
......
...@@ -33,37 +33,34 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -33,37 +33,34 @@ class ElementwiseOp : public framework::OperatorWithKernel {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
if (!ctx->IsRuntime()) { PADDLE_ENFORCE(ctx->HasInput("X"),
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of elementwise op should not be null.");
"Input(X) of elementwise op should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"),
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of elementwise op should not be null.");
"Input(Y) of elementwise op should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"),
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of elementwise op should not be null.");
"Output(Out) of elementwise op should not be null.");
PADDLE_ENFORCE(
PADDLE_ENFORCE(ctx->GetInputsVarType("Y").front() == ctx->GetInputsVarType("Y").front() ==
framework::proto::VarType::LOD_TENSOR, framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the " "The input var's type should be LoDTensor, but the received is %s [%s]",
"received is %s [%s]", ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front());
ctx->GetInputsVarType("Y").front(),
ctx->Inputs("Y").front()); if (ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::LOD_TENSOR) {
if (ctx->GetInputsVarType("X").front() == auto x_dim = ctx->GetInputDim("X");
framework::proto::VarType::LOD_TENSOR) { auto y_dim = ctx->GetInputDim("Y");
auto x_dim = ctx->GetInputDim("X"); PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
auto y_dim = ctx->GetInputDim("Y"); "Rank of first input must >= rank of second input.");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), } else if (ctx->GetInputsVarType("X").front() ==
"Rank of first input must >= rank of second input."); framework::proto::VarType::SELECTED_ROWS) {
} else if (ctx->GetInputsVarType("X").front() == PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) &&
framework::proto::VarType::SELECTED_ROWS) { (ctx->GetInputDim("Y")[0] == 1),
PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) && "For elementwise_op, if X is Sparse, "
(ctx->GetInputDim("Y")[0] == 1), "Y must be scalar.");
"For elementwise_op, if X is Sparse, " } else {
"Y must be scalar."); PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
} else { ctx->GetInputsVarType("X").front());
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
ctx->GetInputsVarType("X").front());
}
} }
ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareDim("X", /*->*/ "Out");
...@@ -128,7 +125,7 @@ The equation is: ...@@ -128,7 +125,7 @@ The equation is:
$$%s$$ $$%s$$
- $X$: a tensor of any dimension. - $X$: a tensor of any dimension.
- $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$. - $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$.
There are two cases for this operator: There are two cases for this operator:
...@@ -138,10 +135,10 @@ There are two cases for this operator: ...@@ -138,10 +135,10 @@ There are two cases for this operator:
For case 2: For case 2:
1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index 1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
for broadcasting $Y$ onto $X$. for broadcasting $Y$ onto $X$.
2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$. 2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of 3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
subsequence, such as shape(Y) = (2, 1) => (2). subsequence, such as shape(Y) = (2, 1) => (2).
For example: For example:
...@@ -155,7 +152,7 @@ For example: ...@@ -155,7 +152,7 @@ For example:
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0 shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0 shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
The inputs $X$ and $Y$ can carry the different LoD information. The inputs $X$ and $Y$ can carry the different LoD information.
But the output only shares the LoD information with the input $X$. But the output only shares the LoD information with the input $X$.
)DOC", )DOC",
......
...@@ -23,57 +23,56 @@ class AdamOp : public framework::OperatorWithKernel { ...@@ -23,57 +23,56 @@ class AdamOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
// PADDLE_ENFORCE(ctx->HasInput("Param"), PADDLE_ENFORCE(ctx->HasInput("Param"),
// "Input(Param) of AdamOp should not be null."); "Input(Param) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("Grad"), PADDLE_ENFORCE(ctx->HasInput("Grad"),
// "Input(Grad) of AdamOp should not be null."); "Input(Grad) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("Moment1"), PADDLE_ENFORCE(ctx->HasInput("Moment1"),
// "Input(Moment1) of AdamOp should not be null."); "Input(Moment1) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("Moment2"), PADDLE_ENFORCE(ctx->HasInput("Moment2"),
// "Input(Moment2) of AdamOp should not be null."); "Input(Moment2) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("LearningRate"), PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
// "Input(LearningRate) of AdamOp should not be null."); "Input(LearningRate) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"), PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
// "Input(Beta1Pow) of AdamOp should not be null."); "Input(Beta1Pow) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"), PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
// "Input(Beta2Pow) of AdamOp should not be null."); "Input(Beta2Pow) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
// "Output(ParamOut) of AdamOp should not be null."); "Output(ParamOut) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"), PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
// "Output(Moment1Out) of AdamOp should not be null."); "Output(Moment1Out) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"), PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
// "Output(Moment2Out) of AdamOp should not be null."); "Output(Moment2Out) of AdamOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate"); auto lr_dims = ctx->GetInputDim("LearningRate");
// PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
// "Learning rate should have 1 dimension"); "Learning rate should have 1 dimension");
auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow"); auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
// PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1, PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
// "Beta1 power accumulator should have 1 dimension"); "Beta1 power accumulator should have 1 dimension");
auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow"); auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
// PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1, PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
// "Beta2 power accumulator should have 1 dimension"); "Beta2 power accumulator should have 1 dimension");
auto param_dims = ctx->GetInputDim("Param"); auto param_dims = ctx->GetInputDim("Param");
// if (ctx->GetInputsVarType("Grad")[0] == if (ctx->GetInputsVarType("Grad")[0] ==
// framework::proto::VarType::LOD_TENSOR) { framework::proto::VarType::LOD_TENSOR) {
// PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
// param_dims, ctx->GetInputDim("Grad"), param_dims, ctx->GetInputDim("Grad"),
// "Param and Grad input of AdamOp should have same dimension"); "Param and Grad input of AdamOp should have same dimension");
// } }
// PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
// param_dims, ctx->GetInputDim("Moment1"), param_dims, ctx->GetInputDim("Moment1"),
// "Param and Moment1 input of AdamOp should have same dimension"); "Param and Moment1 input of AdamOp should have same dimension");
// PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
// param_dims, ctx->GetInputDim("Moment2"), param_dims, ctx->GetInputDim("Moment2"),
// "Param and Moment2 input of AdamOp should have same dimension"); "Param and Moment2 input of AdamOp should have same dimension");
ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("Moment1Out", param_dims); ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims); ctx->SetOutputDim("Moment2Out", param_dims);
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册