提交 60ff05e3 编写于 作者: T tensor-tang

Merge branch 'luotao1-fix_rnn2_test' into fix/jit/exp

test=develop
...@@ -18,12 +18,12 @@ namespace paddle { ...@@ -18,12 +18,12 @@ namespace paddle {
namespace inference { namespace inference {
using namespace framework; // NOLINT using namespace framework; // NOLINT
static std::vector<float> result_data;
struct DataRecord { struct DataRecord {
std::vector<std::vector<std::vector<float>>> link_step_data_all; std::vector<std::vector<std::vector<float>>> link_step_data_all;
std::vector<size_t> lod; std::vector<size_t> lod;
std::vector<std::vector<float>> rnn_link_data; std::vector<std::vector<float>> rnn_link_data;
std::vector<float> result_data;
size_t num_samples; // total number of samples size_t num_samples; // total number of samples
size_t batch_iter{0}; size_t batch_iter{0};
size_t batch_size{1}; size_t batch_size{1};
...@@ -57,6 +57,7 @@ struct DataRecord { ...@@ -57,6 +57,7 @@ struct DataRecord {
std::ifstream file(path); std::ifstream file(path);
std::string line; std::string line;
int num_lines = 0; int num_lines = 0;
result_data.clear();
while (std::getline(file, line)) { while (std::getline(file, line)) {
num_lines++; num_lines++;
std::vector<std::string> data; std::vector<std::string> data;
...@@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) { ...@@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) {
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
// the first inference result // the first inference result
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
PADDLE_ENFORCE_GT(outputs.size(), 0); PADDLE_ENFORCE_GT(outputs.size(), 0);
size_t size = GetSize(outputs[0]); size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(size, 0); PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(outputs[0].data.data()); float *result = static_cast<float *>(outputs[0].data.data());
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
EXPECT_NEAR(result[i], data.result_data[i], 1e-3); EXPECT_NEAR(result[i], result_data[i], 1e-3);
} }
} }
} }
......
...@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/distributed/grpc_client.h"
#include <sys/time.h> #include <sys/time.h>
#include <limits> #include <limits>
#include "glog/logging.h" // For VLOG #include "glog/logging.h" // For VLOG
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h" #include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -336,8 +334,11 @@ void GRPCClient::Proceed() { ...@@ -336,8 +334,11 @@ void GRPCClient::Proceed() {
VLOG(3) << c->GetVarHandlePtr()->String() << " process"; VLOG(3) << c->GetVarHandlePtr()->String() << " process";
c->Process(); c->Process();
} else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
// FIXME(gongwb): parse error_details?
LOG(ERROR) << c->GetVarHandlePtr()->String() LOG(ERROR) << c->GetVarHandlePtr()->String()
<< " meets grpc error:" << c->status_.error_message(); << " meets grpc error, error_code:" << c->status_.error_code()
<< " error_message:" << c->status_.error_message()
<< " error_details:" << c->status_.error_details();
{ {
std::lock_guard<std::mutex> lk(sync_mutex_); std::lock_guard<std::mutex> lk(sync_mutex_);
ok_ = false; ok_ = false;
...@@ -345,7 +346,10 @@ void GRPCClient::Proceed() { ...@@ -345,7 +346,10 @@ void GRPCClient::Proceed() {
c->Finish(false); c->Finish(false);
} else { } else {
LOG(FATAL) << c->GetVarHandlePtr()->String() LOG(FATAL) << c->GetVarHandlePtr()->String()
<< " meets grpc error:" << c->status_.error_message(); << " meets grpc error, error_code:" << c->status_.error_code()
<< " error_message:" << c->status_.error_message()
<< " error_details:" << c->status_.error_details();
c->Finish(false); c->Finish(false);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册