提交 4c256ca6 编写于 作者: P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_cudnn_lstm

......@@ -754,7 +754,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
if (!transfered_inplace_vars.empty()) {
if (run_by_executor_ && !transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered.
TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
}
......@@ -776,6 +776,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
}
}
void OperatorWithKernel::TransferInplaceVarsBack(
const Scope& scope, const std::vector<std::string>& inplace_vars,
const Scope& transfer_scope) const {
......
......@@ -75,8 +75,13 @@ class AucKernel : public framework::OpKernel<T> {
const auto *label_data = label->data<int64_t>();
for (size_t i = 0; i < batch_size; i++) {
uint32_t binIdx = static_cast<uint32_t>(
inference_data[i * inference_width + 1] * num_thresholds);
auto predict_data = inference_data[i * inference_width + 1];
PADDLE_ENFORCE_LE(predict_data, 1,
"The predict data must less or equal 1.");
PADDLE_ENFORCE_GE(predict_data, 0,
"The predict data must gather or equal 0.");
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i]) {
(*stat_pos)[binIdx] += 1.0;
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册