未验证 提交 cf1a8f68 编写于 作者: T Thunderbrook 提交者: GitHub

cherry-pick try catch (#26880)

cherry-pick fix cvm check
test=develop
Co-authored-by: N123malin <malin10@baidu.com>
上级 fdd24939
...@@ -24,7 +24,7 @@ void DeviceWorker::SetDataFeed(DataFeed* data_feed) { ...@@ -24,7 +24,7 @@ void DeviceWorker::SetDataFeed(DataFeed* data_feed) {
} }
template <typename T> template <typename T>
std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) { std::string PrintLodTensorType(Tensor* tensor, int64_t start, int64_t end) {
auto count = tensor->numel(); auto count = tensor->numel();
if (start < 0 || end > count) { if (start < 0 || end > count) {
VLOG(3) << "access violation"; VLOG(3) << "access violation";
...@@ -37,8 +37,7 @@ std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) { ...@@ -37,8 +37,7 @@ std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
return os.str(); return os.str();
} }
std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start, std::string PrintLodTensorIntType(Tensor* tensor, int64_t start, int64_t end) {
int64_t end) {
auto count = tensor->numel(); auto count = tensor->numel();
if (start < 0 || end > count) { if (start < 0 || end > count) {
VLOG(3) << "access violation"; VLOG(3) << "access violation";
...@@ -51,7 +50,7 @@ std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start, ...@@ -51,7 +50,7 @@ std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
return os.str(); return os.str();
} }
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) { std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end) {
std::string out_val; std::string out_val;
if (tensor->type() == proto::VarType::FP32) { if (tensor->type() == proto::VarType::FP32) {
out_val = PrintLodTensorType<float>(tensor, start, end); out_val = PrintLodTensorType<float>(tensor, start, end);
......
...@@ -45,7 +45,7 @@ limitations under the License. */ ...@@ -45,7 +45,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end); std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end);
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index); std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size); bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);
...@@ -148,6 +148,7 @@ class DeviceWorker { ...@@ -148,6 +148,7 @@ class DeviceWorker {
FetchConfig fetch_config_; FetchConfig fetch_config_;
bool use_cvm_; bool use_cvm_;
bool no_cvm_; bool no_cvm_;
std::vector<std::string> all_param_;
}; };
class CPUWorkerBase : public DeviceWorker { class CPUWorkerBase : public DeviceWorker {
......
...@@ -807,7 +807,50 @@ void DownpourWorker::TrainFiles() { ...@@ -807,7 +807,50 @@ void DownpourWorker::TrainFiles() {
} }
} }
if (!need_skip) { if (!need_skip) {
#ifdef PADDLE_WITH_PSLIB
try {
op->Run(*thread_scope_, place_);
} catch (std::exception& e) {
fprintf(stderr, "error message: %s\n", e.what());
auto& ins_id_vec = device_reader_->GetInsIdVec();
size_t batch_size = device_reader_->GetCurBatchSize();
std::string s = "";
for (auto& ins_id : ins_id_vec) {
if (s != "") s += ",";
s += ins_id;
}
fprintf(stderr, "batch_size: %zu, ins_ids_vec: %s\n", batch_size,
s.c_str());
s = "";
for (auto& param : all_param_) {
Variable* var = thread_scope_->FindVar(param);
if (var == nullptr) {
continue;
}
Tensor* tensor = nullptr;
int64_t len = 0;
if (var->IsType<framework::LoDTensor>()) {
tensor = var->GetMutable<LoDTensor>();
len = tensor->numel();
} else if (var->IsType<SelectedRows>()) {
auto selected_rows = var->GetMutable<SelectedRows>();
tensor = selected_rows->mutable_value();
len = tensor->numel();
}
if (!tensor->IsInitialized()) {
continue;
}
s += param + ":" + std::to_string(len) + ":";
s += PrintLodTensor(tensor, 0, len);
fprintf(stderr, "%s\n", s.c_str());
fflush(stderr);
s = "";
}
throw e;
}
#else
op->Run(*thread_scope_, place_); op->Run(*thread_scope_, place_);
#endif
} }
} }
......
...@@ -72,6 +72,7 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc &program) { ...@@ -72,6 +72,7 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
thread_scope_ = &root_scope_->NewScope(); thread_scope_ = &root_scope_->NewScope();
for (auto &var : block.AllVars()) { for (auto &var : block.AllVars()) {
all_param_.push_back(var->Name());
if (var->Persistable()) { if (var->Persistable()) {
auto *ptr = root_scope_->Var(var->Name()); auto *ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
......
...@@ -27,19 +27,11 @@ class CVMOp : public framework::OperatorWithKernel { ...@@ -27,19 +27,11 @@ class CVMOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CVM"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CVM");
OP_INOUT_CHECK(ctx->HasInput("CVM"), "Input", "CVM", "CVM");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "CVM"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "CVM");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto cvm_dims = ctx->GetInputDim("CVM");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, platform::errors::InvalidArgument(
"Input(X)'s rank should be 2.")); "Input(X)'s rank should be 2."));
PADDLE_ENFORCE_EQ(
cvm_dims.size(), 2UL,
platform::errors::InvalidArgument("Input(CVM)'s rank should be 2."));
PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL, platform::errors::InvalidArgument(
"The 2nd dimension of "
"Input(CVM) should be 2."));
if (ctx->Attrs().Get<bool>("use_cvm")) { if (ctx->Attrs().Get<bool>("use_cvm")) {
ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]}); ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册