未验证 提交 cf1a8f68 编写于 作者: T Thunderbrook 提交者: GitHub

cherry-pick try catch (#26880)

cherry-pick fix cvm check
test=develop
Co-authored-by: N123malin <malin10@baidu.com>
上级 fdd24939
......@@ -24,7 +24,7 @@ void DeviceWorker::SetDataFeed(DataFeed* data_feed) {
}
template <typename T>
std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
std::string PrintLodTensorType(Tensor* tensor, int64_t start, int64_t end) {
auto count = tensor->numel();
if (start < 0 || end > count) {
VLOG(3) << "access violation";
......@@ -37,8 +37,7 @@ std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
return os.str();
}
std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
int64_t end) {
std::string PrintLodTensorIntType(Tensor* tensor, int64_t start, int64_t end) {
auto count = tensor->numel();
if (start < 0 || end > count) {
VLOG(3) << "access violation";
......@@ -51,7 +50,7 @@ std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
return os.str();
}
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) {
std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end) {
std::string out_val;
if (tensor->type() == proto::VarType::FP32) {
out_val = PrintLodTensorType<float>(tensor, start, end);
......
......@@ -45,7 +45,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end);
std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end);
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);
......@@ -148,6 +148,7 @@ class DeviceWorker {
FetchConfig fetch_config_;
bool use_cvm_;
bool no_cvm_;
std::vector<std::string> all_param_;
};
class CPUWorkerBase : public DeviceWorker {
......
......@@ -807,7 +807,50 @@ void DownpourWorker::TrainFiles() {
}
}
if (!need_skip) {
#ifdef PADDLE_WITH_PSLIB
try {
op->Run(*thread_scope_, place_);
} catch (std::exception& e) {
fprintf(stderr, "error message: %s\n", e.what());
auto& ins_id_vec = device_reader_->GetInsIdVec();
size_t batch_size = device_reader_->GetCurBatchSize();
std::string s = "";
for (auto& ins_id : ins_id_vec) {
if (s != "") s += ",";
s += ins_id;
}
fprintf(stderr, "batch_size: %zu, ins_ids_vec: %s\n", batch_size,
s.c_str());
s = "";
for (auto& param : all_param_) {
Variable* var = thread_scope_->FindVar(param);
if (var == nullptr) {
continue;
}
Tensor* tensor = nullptr;
int64_t len = 0;
if (var->IsType<framework::LoDTensor>()) {
tensor = var->GetMutable<LoDTensor>();
len = tensor->numel();
} else if (var->IsType<SelectedRows>()) {
auto selected_rows = var->GetMutable<SelectedRows>();
tensor = selected_rows->mutable_value();
len = tensor->numel();
}
if (!tensor->IsInitialized()) {
continue;
}
s += param + ":" + std::to_string(len) + ":";
s += PrintLodTensor(tensor, 0, len);
fprintf(stderr, "%s\n", s.c_str());
fflush(stderr);
s = "";
}
throw e;
}
#else
op->Run(*thread_scope_, place_);
#endif
}
}
......
......@@ -72,6 +72,7 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
thread_scope_ = &root_scope_->NewScope();
for (auto &var : block.AllVars()) {
all_param_.push_back(var->Name());
if (var->Persistable()) {
auto *ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
......
......@@ -27,19 +27,11 @@ class CVMOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CVM");
OP_INOUT_CHECK(ctx->HasInput("CVM"), "Input", "CVM", "CVM");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "CVM");
auto x_dims = ctx->GetInputDim("X");
auto cvm_dims = ctx->GetInputDim("CVM");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, platform::errors::InvalidArgument(
"Input(X)'s rank should be 2."));
PADDLE_ENFORCE_EQ(
cvm_dims.size(), 2UL,
platform::errors::InvalidArgument("Input(CVM)'s rank should be 2."));
PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL, platform::errors::InvalidArgument(
"The 2nd dimension of "
"Input(CVM) should be 2."));
if (ctx->Attrs().Get<bool>("use_cvm")) {
ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册