未验证 提交 ab8af5c4 编写于 作者: W Wilber 提交者: GitHub

[Fluid-Lite] Remove all PADDLE_ENFORCE and PADDLE_THROW. (#3785)

上级 c67b92b1
......@@ -15,8 +15,8 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR LITE_WITH
#full api dynamic library
lite_cc_library(paddle_full_api_shared SHARED SRCS paddle_api.cc light_api.cc cxx_api.cc cxx_api_impl.cc light_api_impl.cc
DEPS paddle_api paddle_api_light paddle_api_full)
add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto)
target_link_libraries(paddle_full_api_shared framework_proto)
add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto op_registry)
target_link_libraries(paddle_full_api_shared framework_proto op_registry)
if(LITE_WITH_X86)
add_dependencies(paddle_full_api_shared xxhash)
target_link_libraries(paddle_full_api_shared xxhash)
......
......@@ -13,18 +13,30 @@
// limitations under the License.
#include "lite/api/cxx_api.h"
#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "lite/api/paddle_use_passes.h"
#include "lite/utils/io.h"
namespace paddle {
namespace lite {
std::vector<std::string> GetAllOps() {
const std::map<std::string, std::string> &op2path =
OpKernelInfoCollector::Global().GetOp2PathDict();
std::vector<std::string> res;
for (const auto &op : op2path) {
res.push_back(op.first);
}
return res;
}
void Predictor::SaveModel(const std::string &dir,
lite_api::LiteModelType model_type,
bool record_info) {
......
......@@ -36,6 +36,8 @@ static const char TAILORD_KERNELS_SOURCE_LIST_FILENAME[] =
".tailored_kernels_source_list";
static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list";
std::vector<std::string> GetAllOps();
/*
* Predictor for inference, input a model, it will optimize and execute it.
*/
......
......@@ -20,8 +20,8 @@ limitations under the License. */
#include "lite/backends/x86/cupti_lib_path.h"
#include "lite/backends/x86/port.h"
#include "lite/backends/x86/warpctc_lib_path.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/env.h"
#include "lite/utils/paddle_enforce.h"
// DEFINE_string(cudnn_dir,
// "",
......@@ -178,7 +178,7 @@ auto error_msg =
#endif // !_WIN32
if (throw_on_error) {
CHECK(dso_handle != nullptr);
// PADDLE_ENFORCE(nullptr != dso_handle, error_msg, dlPath, errorno);
// CHECK(nullptr != dso_handle, error_msg, dlPath, errorno);
} else if (nullptr == dso_handle) {
// LOG(WARNING) << string::Sprintf(error_msg, dlPath, errorno);
}
......
......@@ -319,8 +319,8 @@ void BenchKernelSgd() {
const T lr = 0.1;
auto UnDuplicatedRandomVec = [](
int n, const int64_t lower, const int64_t upper) -> std::vector<int64_t> {
PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower), n - 1);
PADDLE_ENFORCE_GT(n, 0);
CHECK_LE(static_cast<size_t>(upper - lower), n - 1);
CHECK_GT(n, 0);
std::vector<int64_t> all, out;
for (int i = 0; i < n; ++i) {
all.push_back(i);
......
......@@ -129,11 +129,11 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
}
std::unique_ptr<GenBase> CreateJitCode(
const emb_seq_pool_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.table_height, 0);
PADDLE_ENFORCE_GT(attr.table_width, 0);
PADDLE_ENFORCE_GT(attr.index_height, 0);
PADDLE_ENFORCE_GT(attr.index_width, 0);
PADDLE_ENFORCE_GT(attr.out_width, 0);
CHECK_GT(attr.table_height, 0);
CHECK_GT(attr.table_width, 0);
CHECK_GT(attr.index_height, 0);
CHECK_GT(attr.index_width, 0);
CHECK_GT(attr.out_width, 0);
return make_unique<EmbSeqPoolJitCode>(attr, CodeSize(attr));
}
};
......
......@@ -17,7 +17,7 @@
#include <string>
#include "lite/backends/x86/jit/gen/jitcode.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/string.h"
namespace paddle {
namespace lite {
......
......@@ -27,7 +27,7 @@ void MatMulJitCode::genCode() {
preCode();
int block, rest;
const auto groups = packed_groups(n_, k_, &block, &rest);
PADDLE_ENFORCE_GT(groups.front(), 0);
CHECK_GT(groups.front(), 0);
const int block_len = sizeof(float) * block;
const int x_reg_idx = (block == ZMM_FLOAT_BLOCK ? 32 : 16) - 1;
......@@ -116,9 +116,9 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
}
std::unique_ptr<GenBase> CreateJitCode(
const matmul_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.m, 0);
PADDLE_ENFORCE_GT(attr.n, 0);
PADDLE_ENFORCE_GT(attr.k, 0);
CHECK_GT(attr.m, 0);
CHECK_GT(attr.n, 0);
CHECK_GT(attr.k, 0);
return make_unique<MatMulJitCode>(attr, CodeSize(attr));
}
};
......
......@@ -19,7 +19,7 @@
#include <vector>
#include "lite/backends/x86/jit/gen/jitcode.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/string.h"
namespace paddle {
namespace lite {
......@@ -32,7 +32,7 @@ class MatMulJitCode : public JitCode {
size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), m_(attr.m), n_(attr.n), k_(attr.k) {
PADDLE_ENFORCE_EQ(m_, 1, "Only support m==1 yet");
CHECK_EQ(m_, 1) << "Only support m==1 yet";
this->genCode();
}
......
......@@ -69,8 +69,8 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
}
std::unique_ptr<GenBase> CreateJitCode(
const seq_pool_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.w, 0);
PADDLE_ENFORCE_GT(attr.h, 0);
CHECK_GT(attr.w, 0);
CHECK_GT(attr.h, 0);
return make_unique<SeqPoolJitCode>(attr, CodeSize(attr));
}
};
......
......@@ -17,7 +17,7 @@
#include <string>
#include "lite/backends/x86/jit/gen/jitcode.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/string.h"
namespace paddle {
namespace lite {
......@@ -125,8 +125,8 @@ class SeqPoolJitCode : public JitCode {
vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
reg_idx++;
}
PADDLE_ENFORCE_EQ(
reg_idx, rest_used_num_regs, "All heights should use same regs");
CHECK_EQ(reg_idx, rest_used_num_regs)
<< "All heights should use same regs";
for (int i = 0; i < reg_idx; ++i) {
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
}
......
......@@ -17,7 +17,7 @@
#include <memory>
#include <vector>
#include "lite/backends/x86/jit/registry.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -113,9 +113,9 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
}
std::unique_ptr<GenBase> CreateJitCode(
const sgd_attr_t& attr) const override {
PADDLE_ENFORCE_EQ(attr.param_width, attr.grad_width);
PADDLE_ENFORCE_LE(attr.selected_rows_size, attr.grad_height);
PADDLE_ENFORCE_GE(attr.selected_rows_size, 0);
CHECK_EQ(attr.param_width, attr.grad_width);
CHECK_LE(attr.selected_rows_size, attr.grad_height);
CHECK_GE(attr.selected_rows_size, 0);
return make_unique<SgdJitCode>(attr, CodeSize(attr));
}
};
......
......@@ -16,7 +16,7 @@
#include <memory>
#include <vector>
#include "lite/backends/x86/jit/registry.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -76,7 +76,7 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> {
return 96 + (w / YMM_FLOAT_BLOCK) * 16 * 8;
}
std::unique_ptr<GenBase> CreateJitCode(const int64_t& w) const override {
PADDLE_ENFORCE_GT(w, 0);
CHECK_GT(w, 0);
return make_unique<VBroadcastJitCode>(w, CodeSize(w));
}
};
......
......@@ -21,8 +21,8 @@
// posix_memalign
#include "lite/backends/x86/cpu_info.h"
#include "lite/backends/x86/jit/macro.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/env.h"
#include "lite/utils/paddle_enforce.h"
#ifndef _WIN32
#define posix_memalign_free free
......@@ -62,12 +62,10 @@ void* GenBase::operator new(size_t size) {
#ifdef _WIN32
ptr = _aligned_malloc(size, alignment);
#else
PADDLE_ENFORCE_EQ(posix_memalign(&ptr, alignment, size),
0,
"GenBase Alloc %ld error!",
size);
CHECK_EQ(posix_memalign(&ptr, alignment, size), 0) << "GenBase Alloc " << size
<< " error!";
#endif
PADDLE_ENFORCE(ptr, "Fail to allocate GenBase CPU memory: size = %d .", size);
CHECK(ptr) << "Fail to allocate GenBase CPU memory: size = " << size;
return ptr;
}
......
......@@ -14,9 +14,10 @@
#include "lite/backends/x86/jit/helper.h"
#include <algorithm> // tolower
#include <cstring>
#include <numeric>
#include <string>
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -104,12 +105,12 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
int block, rest;
const auto groups = packed_groups(n, k, &block, &rest);
std::for_each(groups.begin(), groups.end(), [&](int i) {
PADDLE_ENFORCE_GT(i, 0, "each element of groups should be larger than 0.");
CHECK_GT(i, 0) << "each element of groups should be larger than 0.";
});
int sum = std::accumulate(groups.begin(), groups.end(), 0);
std::memset(dst, 0, k * sum * block * sizeof(float));
PADDLE_ENFORCE_GE(
sum * block, n, "The packed n should be equal to or larger than n");
CHECK_GE(sum * block, n)
<< "The packed n should be equal to or larger than n";
const int block_len = sizeof(float) * block;
int n_offset = 0;
......
......@@ -23,7 +23,7 @@
#include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/kernel_key.h"
#include "lite/backends/x86/jit/kernel_pool.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -78,8 +78,8 @@ inline const Kernel* GetReferKernel() {
auto& ref_pool = ReferKernelPool::Instance().AllKernels();
KernelKey kkey(KernelTuple::kernel_type, lite::fluid::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
CHECK(ref_iter != ref_pool.end())
<< "Every Kernel should have reference function.";
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
......@@ -94,7 +94,7 @@ template <typename KernelTuple>
inline typename KernelTuple::func_type GetReferFunc() {
auto ker = GetReferKernel<KernelTuple>();
auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
PADDLE_ENFORCE(p, "The Refer kernel should exsit");
CHECK(p) << "The Refer kernel should exsit";
return p->GetFunc();
}
......@@ -125,7 +125,7 @@ std::vector<const Kernel*> GetAllCandidateKernels(
// The last implementation should be reference function on CPUPlace.
auto ref = GetReferKernel<KernelTuple>();
PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty.");
CHECK(ref != nullptr) << "Refer Kernel can not be empty.";
res.emplace_back(ref);
return res;
}
......@@ -140,11 +140,11 @@ GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
std::string name = k->ImplType();
if (name == "JitCode") {
auto i = dynamic_cast<const GenBase*>(k);
PADDLE_ENFORCE(i, "jitcode kernel cast can not fail.");
CHECK(i) << "jitcode kernel cast can not fail.";
res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
} else {
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
PADDLE_ENFORCE(i, "kernel cast can not fail.");
CHECK(i) << "kernel cast can not fail.";
res.emplace_back(std::make_pair(name, i->GetFunc()));
}
}
......@@ -166,7 +166,7 @@ template <typename KernelTuple, typename PlaceType = lite::fluid::CPUPlace>
typename KernelTuple::func_type GetDefaultBestFunc(
const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
PADDLE_ENFORCE_GE(funcs.size(), 1UL);
CHECK_GE(funcs.size(), 1UL);
// Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one,
// which is searched in order and tuned by offline.
......
......@@ -14,7 +14,7 @@
#include "lite/backends/x86/jit/kernel_key.h"
#include <xxhash.h> // XXH64: 13.8 GB/s
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......
......@@ -18,7 +18,7 @@
#include <type_traits>
#include <vector>
#include "lite/backends/x86/jit/kernel_base.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -104,11 +104,11 @@ void EmbSeqPool(const T* table,
const int64_t* idx,
T* out,
const emb_seq_pool_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width);
CHECK_EQ(attr->table_width * attr->index_width, attr->out_width);
auto check_idx_value_valid = [&](int64_t i) {
PADDLE_ENFORCE_LT(
idx[i], attr->table_height, "idx value: %d, i: %d", idx[i], i);
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
CHECK_LT(idx[i], attr->table_height) << "idx value: " << idx[i]
<< " i: " << i;
CHECK_GE(idx[i], 0) << "idx value: " << idx[i] << " i: " << i;
};
for (int64_t w = 0; w != attr->index_width; ++w) {
......@@ -175,22 +175,22 @@ void Sgd(const T* lr,
const int64_t* rows,
T* out,
const sgd_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width);
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height);
CHECK_EQ(attr->param_width, attr->grad_width);
CHECK_LE(attr->selected_rows_size, attr->grad_height);
T scalar = -lr[0];
int width = attr->grad_width;
if (out == param) {
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height);
PADDLE_ENFORCE_GE(h_idx, 0);
CHECK_LT(h_idx, attr->param_height);
CHECK_GE(h_idx, 0);
VAXPY(scalar, grad + i * width, out + h_idx * width, width);
}
} else {
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height);
PADDLE_ENFORCE_GE(h_idx, 0);
CHECK_LT(h_idx, attr->param_height);
CHECK_GE(h_idx, 0);
VScal(&scalar, grad + i * width, out + h_idx * width, width);
VAdd(param + h_idx * width,
out + h_idx * width,
......
......@@ -22,7 +22,6 @@
#include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/macro.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/paddle_enforce.h"
namespace paddle {
namespace lite {
......@@ -480,12 +479,12 @@ void EmbSeqPool(const T* table,
const int64_t* idx,
T* out,
const emb_seq_pool_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width);
CHECK_EQ(attr->table_width * attr->index_width, attr->out_width);
auto check_idx_value_valid = [&](int64_t i) {
PADDLE_ENFORCE_LT(
idx[i], attr->table_height, "idx value: %d, i: %d", idx[i], i);
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
CHECK_LT(idx[i], attr->table_height) << "idx value: " << idx[i]
<< " i: " << i;
CHECK_GE(idx[i], 0) << "idx value: " << idx[i] << " i: " << i;
};
for (int64_t w = 0; w != attr->index_width; ++w) {
......@@ -527,12 +526,12 @@ void Sgd(const T* lr,
const int64_t* rows,
T* out,
const lite::jit::sgd_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width);
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height);
CHECK_EQ(attr->param_width, attr->grad_width);
CHECK_LE(attr->selected_rows_size, attr->grad_height);
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height);
PADDLE_ENFORCE_GE(h_idx, 0);
CHECK_LT(h_idx, attr->param_height);
CHECK_GE(h_idx, 0);
for (int64_t j = 0; j < attr->grad_width; ++j) {
out[h_idx * attr->grad_width + j] =
param[h_idx * attr->grad_width + j] -
......
......@@ -910,8 +910,8 @@ void TestKernelSgd() {
const T lr = 0.1;
auto UnDuplicatedRandomVec = [](
int n, const int64_t lower, const int64_t upper) -> std::vector<int64_t> {
PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower), n - 1);
PADDLE_ENFORCE_GT(n, 0);
CHECK_LE(static_cast<size_t>(upper - lower), n - 1);
CHECK_GT(n, 0);
std::vector<int64_t> all, out;
for (int i = 0; i < n; ++i) {
all.push_back(i);
......
......@@ -116,7 +116,7 @@ class BeamSearchFunctor<TARGET(kX86), T> {
lod[0].assign(high_level.begin(), high_level.end());
lod[1].assign(low_level.begin(), low_level.end());
// if (!lite::fluid::CheckLoD(lod)) {
// //PADDLE_THROW("lod %s is not right", framework::LoDToString(lod));
// //LOG(FATAL)<<"lod %s is not right", framework::LoDToString(lod));
//}
selected_ids->set_lod(lod);
selected_scores->set_lod(lod);
......
......@@ -23,7 +23,7 @@ namespace math {
MatDescriptor CreateMatrixDescriptor(const lite::DDimLite &tensor_dim,
int num_flatten_cols,
bool trans) {
PADDLE_ENFORCE_GT(tensor_dim.size(), 1u);
CHECK_GT(tensor_dim.size(), 1u);
MatDescriptor retv;
if (num_flatten_cols > 1) {
auto flatten_dim = tensor_dim.Flatten2D(num_flatten_cols);
......
......@@ -287,22 +287,22 @@ struct CBlas<double> {
template <>
struct CBlas<lite::fluid::float16> {
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
static void GEMM(...) { LOG(FATAL) << "float16 GEMM not supported on CPU"; }
static void SMM_GEMM(...) {
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
LOG(FATAL) << "float16 SMM_GEMM not supported on CPU";
}
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
static void VMUL(...) { LOG(FATAL) << "float16 VMUL not supported on CPU"; }
static void VEXP(...) { LOG(FATAL) << "float16 VEXP not supported on CPU"; }
static void VSQUARE(...) {
PADDLE_THROW("float16 VSQUARE not supported on CPU");
LOG(FATAL) << "float16 VSQUARE not supported on CPU";
}
static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
static void ASUM(...) { PADDLE_THROW("float16 ASUM not supported on CPU"); };
static void VPOW(...) { LOG(FATAL) << "float16 VPOW not supported on CPU"; }
static void DOT(...) { LOG(FATAL) << "float16 DOT not supported on CPU"; };
static void SCAL(...) { LOG(FATAL) << "float16 SCAL not supported on CPU"; };
static void ASUM(...) { LOG(FATAL) << "float16 ASUM not supported on CPU"; };
#ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
LOG(FATAL) << "float16 GEMM_BATCH not supported on CPU";
}
#endif
};
......@@ -461,11 +461,11 @@ void Blas<Target>::MatMul(const lite::Tensor &mat_a,
auto dim_a = mat_a.dims();
auto dim_b = mat_b.dims();
auto dim_out = mat_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
// PADDLE_ENFORCE(
// mat_a.target() == mat_b.target() && mat_a.target() == mat_out->target(),
// "The targets of matrices must be same");
CHECK(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2)
<< "The input and output of matmul be matrix";
// CHECK(
// mat_a.target() == mat_b.target() && mat_a.target() == mat_out->target())
// << "The targets of matrices must be same";
int M = dim_out[0];
int N = dim_out[1];
......@@ -746,7 +746,7 @@ void Blas<Target>::MatMul(const lite::Tensor &mat_a,
T alpha,
lite::Tensor *mat_out,
T beta) const {
PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
CHECK_EQ(dim_a.width_, dim_b.height_);
CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
......@@ -761,8 +761,8 @@ void Blas<Target>::MatMul(const lite::Tensor &mat_a,
beta,
mat_out->template mutable_data<T>());
} else {
PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
CHECK(dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 ||
dim_b.batch_size_ == 0);
this->template BatchedGEMM<T>(
transA,
transB,
......
......@@ -146,7 +146,7 @@ class ContextProjectFunctor {
}
}
if (padding_trainable) {
PADDLE_ENFORCE(padding_data != nullptr);
CHECK(padding_data != nullptr);
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
if (lod_level_0[i] == lod_level_0[i + 1]) continue;
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <functional>
#include <string>
#include "lite/backends/x86/cpu_info.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
#ifdef PADDLE_WITH_MKLML
#include "lite/backends/x86/mklml.h"
......@@ -652,7 +652,7 @@ class VecActivations {
} else if (type == "identity" || type == "") {
return vec_identity<T, isa>;
}
PADDLE_THROW("Not support type: %s", type);
LOG(FATAL) << "Not support type: " << type;
}
};
......
......@@ -57,7 +57,7 @@ class CrossEntropyFunctor<lite::TargetType::kX86, T> {
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < num_remain; j++) {
int lbl = label_data[i * num_remain + j];
PADDLE_ENFORCE((lbl >= 0 && lbl < axis_dim) || lbl == ignore_index);
CHECK((lbl >= 0 && lbl < axis_dim) || lbl == ignore_index);
int index = i * num_classes + lbl * num_remain + j;
int loss_idx = i * num_remain + j;
loss_data[loss_idx] =
......
......@@ -27,7 +27,7 @@ namespace math {
template <typename T>
struct TolerableValue {
HOSTDEVICE T operator()(const T& x) const {
PADDLE_ENFORCE(static_cast<bool>(std::is_floating_point<T>::value));
CHECK(static_cast<bool>(std::is_floating_point<T>::value));
const T kApproInf = 1e20;
if (x == INFINITY) return kApproInf;
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <math.h>
#include <string>
#include "lite/backends/x86/cpu_info.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -46,8 +46,6 @@ inline ActivationType GetActivationType(const std::string &type) {
return ActivationType::kIdentity;
}
LOG(ERROR) << "Not support type " << type;
// PADDLE_ENFORCE(false, "Not support type %s", type);
// PADDLE_THROW("Not support type %s.", type);
return ActivationType();
}
......
......@@ -13,7 +13,7 @@ limitations under the License. */
#include "lite/backends/x86/math/detail/activation_functions.h"
#include "lite/core/context.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "lite/backends/x86/math/im2col.h"
#include <vector>
#include "lite/backends/x86/math/im2col_cfo_cpu.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -38,8 +38,8 @@ class Im2ColFunctor<lite::x86::math::ColFormat::kCFO,
const std::vector<int>& stride,
const std::vector<int>& padding,
lite::Tensor* col) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col->dims().size() == 5);
CHECK_EQ(im.dims().size(), 3);
CHECK_EQ(col->dims().size(), 5);
if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
dilation[1] == 1) {
......@@ -72,8 +72,8 @@ class Col2ImFunctor<lite::x86::math::ColFormat::kCFO,
const std::vector<int>& stride,
const std::vector<int>& padding,
lite::Tensor* im) {
PADDLE_ENFORCE(im->dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
CHECK_EQ(im->dims().size(), 3);
CHECK_EQ(col.dims().size(), 5);
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
......@@ -82,20 +82,20 @@ class Col2ImFunctor<lite::x86::math::ColFormat::kCFO,
int col_height = col.dims()[3];
int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
((dilation[0] * (filter_height - 1) + 1))) /
stride[0] +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
((dilation[1] * (filter_width - 1) + 1))) /
stride[1] +
1,
col_width,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
CHECK_EQ((im_height + padding[0] + padding[2] -
((dilation[0] * (filter_height - 1) + 1))) /
stride[0] +
1,
col_height)
<< "Output_height and padding(padding_up, padding_down) are "
"inconsistent.";
CHECK_EQ((im_width + padding[1] + padding[3] -
((dilation[1] * (filter_width - 1) + 1))) /
stride[1] +
1,
col_width)
<< "Output_height and padding(padding_up, padding_down) are "
"inconsistent.";
int channels_col = im_channels * filter_height * filter_width;
......@@ -150,8 +150,8 @@ class Im2ColFunctor<lite::x86::math::ColFormat::kOCF,
const std::vector<int>& stride,
const std::vector<int>& padding,
lite::Tensor* col) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col->dims().size() == 5);
CHECK_EQ(im.dims().size(), 3);
CHECK_EQ(col->dims().size(), 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
......@@ -214,8 +214,8 @@ class Col2ImFunctor<lite::x86::math::ColFormat::kOCF,
const std::vector<int>& stride,
const std::vector<int>& padding,
lite::Tensor* im) {
PADDLE_ENFORCE(im->dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
CHECK_EQ(im->dims().size(), 3);
CHECK_EQ(col.dims().size(), 5);
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
......@@ -224,16 +224,16 @@ class Col2ImFunctor<lite::x86::math::ColFormat::kOCF,
int col_height = col.dims()[0];
int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ(
CHECK_EQ(
(im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
col_height)
<< "Output_height and padding(padding_up, padding_down) are "
"inconsistent.";
CHECK_EQ(
(im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
col_width)
<< "col_width and padding(padding_left, padding_right) are "
"inconsistent.";
T* im_data = im->template mutable_data<T>();
const T* col_data = col.data<T>();
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include "lite/backends/x86/math/detail/activation_functions.h"
#include "lite/core/context.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......
......@@ -121,8 +121,8 @@ struct RowwiseAdd<lite::TargetType::kX86, T> {
lite::Tensor* output) {
const auto& in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector.numel(), size);
PADDLE_ENFORCE_EQ(output->dims(), in_dims);
CHECK_EQ(vector.numel(), size);
CHECK_EQ(output->dims(), in_dims);
const T* input_data = input.data<T>();
const T* vector_data = vector.data<T>();
......
......@@ -20,8 +20,8 @@ limitations under the License. */
#include "lite/core/op_lite.h"
#include "lite/core/tensor.h"
#include "lite/fluid/float16.h"
#include "lite/utils/paddle_enforce.h"
//#include "lite/tensor_util.h"
#include "lite/utils/cp_logging.h"
// #include "lite/tensor_util.h"
namespace paddle {
namespace lite {
......
......@@ -59,7 +59,7 @@ void ColwiseSum<Target, T>::operator()(const lite::Context<Target>& context,
lite::TensorLite* out) {
auto in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(out->numel(), size);
CHECK_EQ(out->numel(), size);
auto in = lite::fluid::EigenMatrix<T>::From(input);
auto vec = lite::fluid::EigenVector<T>::Flatten(*out);
......@@ -81,7 +81,7 @@ class ColwiseSum<lite::TargetType::kX86, T> {
auto& in_dims = input.dims();
auto height = in_dims[0];
auto size = in_dims[1];
PADDLE_ENFORCE_EQ(out->numel(), size);
CHECK_EQ(out->numel(), size);
T* out_buf = out->template mutable_data<T>(out->target());
const T* in_buf = input.data<T>();
......@@ -103,8 +103,8 @@ void RowwiseMean<Target, T>::operator()(const lite::Context<Target>& context,
const lite::TensorLite& input,
lite::TensorLite* out) {
auto in_dims = input.dims();
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
PADDLE_ENFORCE_EQ(out->numel(), in_dims[0]);
CHECK_EQ(in_dims.size(), 2U);
CHECK_EQ(out->numel(), in_dims[0]);
auto in = lite::fluid::EigenMatrix<T>::From(input);
auto vec = lite::fluid::EigenVector<T>::Flatten(*out);
......@@ -124,10 +124,10 @@ class RowwiseMean<lite::TargetType::kX86, T> {
const lite::TensorLite& input,
lite::TensorLite* out) {
auto& in_dims = input.dims();
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
CHECK_EQ(in_dims.size(), 2U);
auto height = in_dims[0];
auto size = in_dims[1];
PADDLE_ENFORCE_EQ(out->numel(), height);
CHECK_EQ(out->numel(), height);
auto inv_size = 1.0 / size;
T* out_buf = out->template mutable_data<T>(out->target());
const T* in_buf = input.data<T>();
......@@ -147,8 +147,8 @@ void RowwiseSum<Target, T>::operator()(const lite::Context<Target>& context,
const lite::TensorLite& input,
lite::TensorLite* out) {
auto in_dims = input.dims();
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
PADDLE_ENFORCE_EQ(out->numel(), in_dims[0]);
CHECK_EQ(in_dims.size(), 2U);
CHECK_EQ(out->numel(), in_dims[0]);
auto in = lite::fluid::EigenMatrix<T>::From(input);
auto vec = lite::fluid::EigenVector<T>::Flatten(*out);
......@@ -168,10 +168,10 @@ class RowwiseSum<lite::TargetType::kX86, T> {
const lite::TensorLite& input,
lite::TensorLite* out) {
auto& in_dims = input.dims();
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
CHECK_EQ(in_dims.size(), 2U);
auto height = in_dims[0];
auto size = in_dims[1];
PADDLE_ENFORCE_EQ(out->numel(), height);
CHECK_EQ(out->numel(), height);
T* out_buf = out->template mutable_data<T>(out->target());
const T* in_buf = input.data<T>();
......
......@@ -273,7 +273,7 @@ TEST(math_funciton, set_constant) {
auto* ctx = new paddle::platform::CPUDeviceContext();
paddle::operators::math::set_constant(*ctx, &t, 10);
for (int64_t i = 0; i < t.numel(); ++i) {
PADDLE_ENFORCE_EQ(10, t.data<int>()[i]);
CHECK_EQ(10, t.data<int>()[i]);
}
delete ctx;
}
......
......@@ -32,7 +32,7 @@ namespace math {
class Sampler {
public:
explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) {
// PADDLE_ENFORCE_GT(range, 0, "Range should be greater than 0.");
// CHECK_GT(range, 0, "Range should be greater than 0.");
if (seed == 0) {
std::random_device r;
seed_ = r();
......
......@@ -31,7 +31,7 @@ struct SelectedRowsAdd<lite::TargetType::kX86, T> {
const fluid::SelectedRows& input2,
fluid::SelectedRows* output) {
auto in1_height = input1.height();
PADDLE_ENFORCE_EQ(in1_height, input2.height());
CHECK_EQ(in1_height, input2.height());
output->set_height(in1_height);
auto& in1_rows = input1.rows();
......@@ -49,8 +49,8 @@ struct SelectedRowsAdd<lite::TargetType::kX86, T> {
auto& in2_value = input2.value();
auto in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());
CHECK_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
CHECK_EQ(in1_row_numel, out_value->numel() / out_rows.size());
auto* out_data = out_value->template mutable_data<T>();
auto* in1_data = in1_value.data<T>();
......@@ -73,15 +73,15 @@ struct SelectedRowsAddTensor<lite::TargetType::kX86, T> {
auto in1_height = input1.height();
auto in2_dims = input2.dims();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);
CHECK_EQ(in1_height, in2_dims[0]);
CHECK_EQ(in1_height, out_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);
CHECK_EQ(in1_row_numel, input2.numel() / in1_height);
CHECK_EQ(in1_row_numel, output->numel() / in1_height);
SetConstant<lite::TargetType::kX86, T> functor;
functor(context, output, 0.0);
......@@ -113,7 +113,7 @@ struct SelectedRowsAddTo<lite::TargetType::kX86, T> {
const int64_t input2_offset,
fluid::SelectedRows* input2) {
auto in1_height = input1.height();
PADDLE_ENFORCE_EQ(in1_height, input2->height());
CHECK_EQ(in1_height, input2->height());
auto& in1_rows = input1.rows();
auto& in2_rows = *(input2->mutable_rows());
......@@ -149,7 +149,7 @@ struct SelectedRowsSumTo<lite::TargetType::kX86, T> {
auto& in_rows = (*iter)->rows();
size += in_rows.end() - in_rows.begin();
auto in1_height = (*iter)->height();
PADDLE_ENFORCE_EQ(in1_height, input2->height());
CHECK_EQ(in1_height, input2->height());
}
// concat rows
std::vector<int64_t> in2_rows;
......@@ -185,13 +185,13 @@ struct SelectedRowsAddToTensor<lite::TargetType::kX86, T> {
auto in1_height = input1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
CHECK_EQ(in1_height, in2_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
CHECK_EQ(in1_row_numel, input2->numel() / in1_height);
auto* in1_data = in1_value.data<T>();
auto* input2_data = input2->template mutable_data<T>();
......@@ -291,12 +291,11 @@ struct MergeAdd<lite::TargetType::kX86, T> {
if (input->rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(input_width,
input->value().dims()[1],
"all input should have same "
"dimension except for the first one");
PADDLE_ENFORCE_EQ(
input_height, input->height(), "all input should have same height");
CHECK_EQ(input_width, input->value().dims()[1])
<< "all input should have same "
"dimension except for the first one";
CHECK_EQ(input_height, input->height())
<< "all input should have same height";
row_num += input->rows().size();
merged_row_set.insert(input->rows().begin(), input->rows().end());
}
......@@ -376,13 +375,13 @@ struct UpdateToTensor<lite::TargetType::kX86, T> {
lite::Tensor* input2) {
auto in1_height = input1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
CHECK_EQ(in1_height, in2_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
CHECK_EQ(in1_row_numel, input2->numel() / in1_height);
auto* in1_data = in1_value.data<T>();
auto* input2_data = input2->template data<T>();
......
......@@ -30,12 +30,10 @@ class CopyMatrixRowsFunctor<lite::TargetType::kX86, T> {
const uint64_t* index = index_lod.data();
const auto& src_dims = src.dims();
const auto& dst_dims = dst->dims();
PADDLE_ENFORCE_EQ(
src_dims.size(), 2UL, "The src must be matrix with rank 2.");
PADDLE_ENFORCE_EQ(
dst_dims.size(), 2UL, "The dst must be matrix with rank 2.");
PADDLE_ENFORCE_EQ(
src_dims[1], dst_dims[1], "The width of src and dst must be same.");
CHECK_EQ(src_dims.size(), 2UL) << "The src must be matrix with rank 2.";
CHECK_EQ(dst_dims.size(), 2UL) << "The dst must be matrix with rank 2.";
CHECK_EQ(src_dims[1], dst_dims[1])
<< "The width of src and dst must be same.";
auto height = dst_dims[0];
auto width = dst_dims[1];
auto* src_data = src.data<T>();
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "lite/core/context.h"
#include "lite/core/tensor.h"
#include "lite/fluid/eigen.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -66,21 +66,18 @@ class LoDTensor2BatchFunctor {
bool is_reverse = false) const {
if (!is_cal_batch_lod) {
auto lods = batch->lod();
PADDLE_ENFORCE_GT(lods.size(),
2UL,
"The LoD of LoDTensor should inlcude at least 2-level "
"sequence information.");
PADDLE_ENFORCE_EQ(
lods[1].size(),
static_cast<size_t>(lod_tensor.dims()[0]),
"The LoD information should be consistent with the dims.");
CHECK_GT(lods.size(), 2UL)
<< "The LoD of LoDTensor should inlcude at least 2-level "
"sequence information.";
CHECK_EQ(lods[1].size(), static_cast<size_t>(lod_tensor.dims()[0]))
<< "The LoD information should be consistent with the dims.";
CopyMatrixRowsFunctor<Target, T> to_batch;
to_batch(context, lod_tensor, lods[1], batch, true);
return;
}
auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
CHECK_EQ(lods.size(), 1UL) << "Only support one level sequence now.";
const auto& lod = lods[0];
......@@ -165,14 +162,11 @@ class Batch2LoDTensorFunctor {
const lite::Tensor& batch,
lite::Tensor* lod_tensor) const {
auto in_lod = batch.lod();
PADDLE_ENFORCE_GT(in_lod.size(),
2UL,
"The LoD of LoDTensor should inlcude at least 2-level "
"sequence information.");
PADDLE_ENFORCE_EQ(
in_lod[1].size(),
static_cast<size_t>(lod_tensor->dims()[0]),
"The LoD information should be consistent with the dims.");
CHECK_GT(in_lod.size(), 2UL)
<< "The LoD of LoDTensor should inlcude at least 2-level "
"sequence information.";
CHECK_EQ(in_lod[1].size(), static_cast<size_t>(lod_tensor->dims()[0]))
<< "The LoD information should be consistent with the dims.";
CopyMatrixRowsFunctor<Target, T> to_seq;
to_seq(context, batch, in_lod[1], lod_tensor, false);
}
......
......@@ -37,10 +37,9 @@ void CopyValidData(lite::Tensor* dst_tensor,
layout == kBatchLengthWidth ? step_width : seq_num * step_width;
for (int seq_idx = 0; seq_idx < seq_num; ++seq_idx) {
int valid_seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
PADDLE_ENFORCE_GE(
pad_seq_len,
valid_seq_len,
"The padded sequence length can not be less than its original length.");
CHECK_GE(pad_seq_len, valid_seq_len) << "The padded sequence length can "
"not be less than its original "
"length.";
int seq_data_offset = seq_offsets[seq_idx] * step_width;
int pad_data_offset = layout == kBatchLengthWidth
? seq_idx * pad_seq_len * step_width
......@@ -108,9 +107,9 @@ class PaddingLoDTensorFunctor<lite::TargetType::kX86, T> {
pad_seq_len,
step_width,
layout);
PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width,
"The numel of 'pad_value' can only be 1 or be equal to the "
"'step_width'.");
CHECK(pad_value.numel() == 1 || pad_value.numel() == step_width)
<< "The numel of 'pad_value' can only be 1 or be equal to the "
"'step_width'.";
// fill padding value
T* pad_data = pad_tensor->template mutable_data<T>();
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "lite/core/context.h"
#include "lite/core/tensor.h"
#include "lite/fluid/lod.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -46,15 +46,14 @@ inline static void CheckDims(const lite::DDim& seq_tensor_dims,
int64_t padded_seq_len,
int64_t step_width,
const PadLayout& layout) {
PADDLE_ENFORCE_EQ(static_cast<size_t>(seq_tensor_dims[0]),
seq_offset.back(),
"Value of 1st dimension of the sequence tensor should be "
"equal to sum of lengths of all sequences.");
CHECK_EQ(static_cast<size_t>(seq_tensor_dims[0]), seq_offset.back())
<< "Value of 1st dimension of the sequence tensor should be "
"equal to sum of lengths of all sequences.";
PADDLE_ENFORCE(seq_tensor_dims.size() + 1 == pad_tensor_dims.size() ||
seq_tensor_dims.size() == pad_tensor_dims.size(),
"pad_tensor's rank should be 1 greater than seq_tensor's "
"rank, or be equal with it.");
CHECK(seq_tensor_dims.size() + 1 == pad_tensor_dims.size() ||
seq_tensor_dims.size() == pad_tensor_dims.size())
<< "pad_tensor's rank should be 1 greater than seq_tensor's "
"rank, or be equal with it.";
}
/*
......
......@@ -46,12 +46,12 @@ class MaxSeqPoolFunctor {
auto in_dims = input.dims();
auto out_dims = output->dims();
auto idx_dims = index->dims();
PADDLE_ENFORCE_GT(in_dims.size(), 1u);
PADDLE_ENFORCE_GT(out_dims.size(), 1u);
CHECK_GT(in_dims.size(), 1u);
CHECK_GT(out_dims.size(), 1u);
for (size_t i = 1; i < in_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
CHECK_EQ(in_dims[i], out_dims[i]);
}
PADDLE_ENFORCE_EQ(idx_dims, out_dims);
CHECK_EQ(idx_dims, out_dims);
auto starts = input.lod()[0];
const T* in_data = input.data<T>();
......@@ -95,10 +95,10 @@ class MaxSeqPoolFunctor<T, true> {
lite::Tensor* index) {
auto in_dims = input.dims();
auto out_dims = output->dims();
PADDLE_ENFORCE_GT(in_dims.size(), 1u);
PADDLE_ENFORCE_GT(out_dims.size(), 1u);
CHECK_GT(in_dims.size(), 1u);
CHECK_GT(out_dims.size(), 1u);
for (size_t i = 1; i < in_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
CHECK_EQ(in_dims[i], out_dims[i]);
}
auto starts = input.lod()[0];
......@@ -136,12 +136,12 @@ class MaxSeqPoolGradFunctor {
auto og_dims = out_grad.dims();
auto ig_dims = in_grad->dims();
auto idx_dims = index.dims();
PADDLE_ENFORCE_GT(og_dims.size(), 1);
PADDLE_ENFORCE_GT(ig_dims.size(), 1);
CHECK_GT(og_dims.size(), 1);
CHECK_GT(ig_dims.size(), 1);
for (size_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]);
CHECK_EQ(og_dims[i], ig_dims[i]);
}
PADDLE_ENFORCE_EQ(idx_dims, og_dims);
CHECK_EQ(idx_dims, og_dims);
const T* og_data = out_grad.data<T>();
const int* max_index = index.data<int>();
......@@ -236,7 +236,7 @@ class SumSeqPoolGradFunctor {
auto lod = in_grad->lod()[0];
int64_t out_w = out_grad.numel() / out_grad.dims()[0];
int64_t in_w = in_grad->numel() / in_grad->dims()[0];
PADDLE_ENFORCE(in_w == out_w);
CHECK(in_w == out_w);
const T* out_g_data = out_grad.data<T>();
T* in_g_data = in_grad->template mutable_data<T>(TARGET(kX86));
auto blas = math::GetBlas<TARGET(kX86), T>(context);
......@@ -330,7 +330,7 @@ class SequencePoolFunctor<TARGET(kX86), T> {
out_e.device(eigen_device) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h));
} else {
PADDLE_THROW("unsupported pooling pooltype");
LOG(FATAL) << "unsupported pooling pooltype";
}
}
}
......@@ -389,7 +389,7 @@ class SequencePoolGradFunctor<TARGET(kX86), T> {
} else if (pooltype == "FIRST") {
in_g_e.chip(0, 0).device(eigen_device) = out_g_e_v;
} else {
PADDLE_THROW("unsupported pooling pooltype");
LOG(FATAL) << "unsupported pooling pooltype";
}
}
}
......
......@@ -50,9 +50,9 @@ void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
in_grad.mutable_data<T>(in_dims, context->GetPlace());
// check tensor contruction result
PADDLE_ENFORCE_EQ(in_grad.dims().size(), out_grad.dims().size());
CHECK_EQ(in_grad.dims().size(), out_grad.dims().size());
for (int64_t i = 1; i < out_grad.dims().size(); ++i) {
PADDLE_ENFORCE_EQ(in_grad.dims()[i], out_grad.dims()[i]);
CHECK_EQ(in_grad.dims()[i], out_grad.dims()[i]);
}
// call functor
......
......@@ -55,7 +55,7 @@ void Tree2ColUtil::construct_tree(const lite::Tensor &EdgeSet,
std::vector<std::vector<int>> *tr,
size_t *node_count) {
auto edge_set_dims = EdgeSet.dims();
PADDLE_ENFORCE_EQ(edge_set_dims[1], 2);
CHECK_EQ(edge_set_dims[1], 2);
int64_t edge_count = EdgeSet.numel();
const int *edge_data = EdgeSet.data<int>();
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/backends/x86/math/unpooling.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -41,7 +41,7 @@ class Unpool2dMaxFunctor<lite::TargetType::kX86, T> {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
CHECK(index < output_feasize) << "err index in unpooling!";
output_data[index] = input_data[i];
}
input_data += input_feasize;
......@@ -77,7 +77,7 @@ class Unpool2dMaxGradFunctor<lite::TargetType::kX86, T> {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
CHECK(index < output_feasize) << "err index in unpooling!";
input_grad_data[i] = output_grad_data[index];
}
input_grad_data += input_feasize;
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#include "lite/backends/x86/math/vol2col.h"
#include <vector>
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -36,8 +36,8 @@ class Vol2ColFunctor<lite::TargetType::kX86, T> {
const std::vector<int>& strides,
const std::vector<int>& paddings,
lite::Tensor* col) const {
PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col->dims().size() == 7);
CHECK_EQ(vol.dims().size(), 4);
CHECK_EQ(col->dims().size(), 7);
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
......@@ -52,27 +52,27 @@ class Vol2ColFunctor<lite::TargetType::kX86, T> {
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
"mismatching.");
CHECK_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth)
<< "input_depth and output_depth are "
"mismatching.";
CHECK_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height)
<< "input_height and output_height are "
"mismatching.";
CHECK_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width)
<< "input_width and output_width are "
"mismatching.";
const T* vol_data = vol.data<T>();
T* col_data = col->template mutable_data<T>();
......@@ -122,8 +122,8 @@ class Col2VolFunctor<lite::TargetType::kX86, T> {
const std::vector<int>& strides,
const std::vector<int>& paddings,
lite::Tensor* vol) const {
PADDLE_ENFORCE(vol->dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7);
CHECK_EQ(vol->dims().size(), 4);
CHECK_EQ(col.dims().size(), 7);
int input_channels = vol->dims()[0];
int input_depth = vol->dims()[1];
......@@ -138,27 +138,27 @@ class Col2VolFunctor<lite::TargetType::kX86, T> {
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
"mismatching.");
CHECK_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth)
<< "input_depth and output_depth are "
"mismatching.";
CHECK_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height)
<< "input_height and output_height are "
"mismatching.";
CHECK_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width)
<< "input_width and output_width are "
"mismatching.";
T* vol_data = vol->template mutable_data<T>();
const T* col_data = col.data<T>();
......
......@@ -157,8 +157,8 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
///////////////////////////////////////////////////////////////////////////////
if (enable_int8) {
std::string weight_name = conv_op_desc->Input("Filter").front();
PADDLE_ENFORCE(conv_op_desc->HasInputScale(weight_name),
"INT8 mode: Conv should has weight_scale attr");
CHECK(conv_op_desc->HasInputScale(weight_name))
<< "INT8 mode: Conv should has weight_scale attr";
auto conv_weight_d = conv_weight_t->mutable_data<int8_t>();
// compute new conv_weight for int8
auto weight_scale =
......
......@@ -18,7 +18,7 @@
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......
......@@ -67,7 +67,7 @@ framework::proto::VarType::Type ToDataType(std::type_index type) {
if (it != gDataTypeMap().cpp_to_proto_.end()) {
return it->second;
}
PADDLE_THROW("Not support %s as tensor type", type.name());
LOG(FATAL) << "Not support " << type.name() << " as tensor type";
return static_cast<framework::proto::VarType::Type>(-1);
}
......@@ -76,8 +76,8 @@ std::type_index ToTypeIndex(framework::proto::VarType::Type type) {
if (it != gDataTypeMap().proto_to_cpp_.end()) {
return it->second;
}
PADDLE_THROW("Not support framework::proto::VarType::Type(%d) as tensor type",
static_cast<int>(type));
LOG(FATAL) << "Not support framework::proto::VarType::Type("
<< static_cast<int>(type) << ") as tensor type";
return std::type_index(typeid(void));
}
......@@ -86,8 +86,8 @@ std::string DataTypeToString(const framework::proto::VarType::Type type) {
if (it != gDataTypeMap().proto_to_str_.end()) {
return it->second;
}
PADDLE_THROW("Not support framework::proto::VarType::Type(%d) as tensor type",
static_cast<int>(type));
LOG(FATAL) << "Not support framework::proto::VarType::Type("
<< static_cast<int>(type) << ") as tensor type";
return std::string();
}
......@@ -96,7 +96,8 @@ size_t SizeOfType(framework::proto::VarType::Type type) {
if (it != gDataTypeMap().proto_to_size_.end()) {
return it->second;
}
PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type).c_str());
LOG(FATAL) << "Not support " << DataTypeToString(type).c_str()
<< " as tensor type";
return 0;
}
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <typeindex>
#include "lite/core/framework.pb.h"
#include "lite/fluid/float16.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -72,7 +72,7 @@ inline void VisitDataType(framework::proto::VarType::Type type,
_ForEachDataType_(VisitDataTypeCallback);
#undef VisitDataTypeCallback
PADDLE_THROW("Not supported %d", type);
LOG(FATAL) << "Not supported " << type;
}
extern std::string DataTypeToString(const framework::proto::VarType::Type type);
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <vector>
#include "lite/core/tensor.h"
#include "lite/fluid/float16.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
......@@ -30,7 +30,7 @@ struct EigenDim {
using Type = Eigen::DSizes<Eigen::DenseIndex, D>;
static Type From(const lite::DDim& dims) {
PADDLE_ENFORCE_EQ(dims.size(), D, "D must match DDim::size");
CHECK_EQ(dims.size(), D) << "D must match DDim::size";
Type ret;
for (size_t d = 0; d < dims.size(); d++) {
ret[d] = dims[d];
......@@ -39,7 +39,7 @@ struct EigenDim {
}
static Type From(const DDim::value_type length) {
PADDLE_ENFORCE_EQ(D, 1, "D must be 1.");
CHECK_EQ(D, 1) << "D must be 1.";
Type ret;
ret[0] = length;
return ret;
......@@ -84,16 +84,16 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::Type Reshape(Tensor& tensor, // NOLINT
int num_col_dims) {
int rank = tensor.dims().size();
PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
"`num_col_dims` must be between (0, rank_of_tensor).");
CHECK(num_col_dims > 0 && num_col_dims < rank)
<< "`num_col_dims` must be between (0, rank_of_tensor).";
return EigenMatrix::From(tensor, tensor.dims().Flatten2D(num_col_dims));
}
static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
int num_col_dims) {
int rank = tensor.dims().size();
PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
"`num_col_dims` must be between (0, rank_of_tensor).");
CHECK(num_col_dims > 0 && num_col_dims < rank)
<< "`num_col_dims` must be between (0, rank_of_tensor).";
return EigenMatrix::From(tensor, tensor.dims().Flatten2D(num_col_dims));
}
};
......
......@@ -20,7 +20,7 @@ limitations under the License. */
#include <mutex> // NOLINT
#endif // !_WIN32
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -33,17 +33,15 @@ struct RWLock {
~RWLock() { pthread_rwlock_destroy(&lock_); }
inline void RDLock() {
PADDLE_ENFORCE_EQ(
pthread_rwlock_rdlock(&lock_), 0, "acquire read lock failed");
CHECK_EQ(pthread_rwlock_rdlock(&lock_), 0) << "acquire read lock failed";
}
inline void WRLock() {
PADDLE_ENFORCE_EQ(
pthread_rwlock_wrlock(&lock_), 0, "acquire write lock failed");
CHECK_EQ(pthread_rwlock_wrlock(&lock_), 0) << "acquire write lock failed";
}
inline void UNLock() {
PADDLE_ENFORCE_EQ(pthread_rwlock_unlock(&lock_), 0, "unlock failed");
CHECK_EQ(pthread_rwlock_unlock(&lock_), 0) << "unlock failed";
}
private:
......
......@@ -119,7 +119,7 @@ void DeserializeFromStream(
// the 1st field, unit32_t version for SelectedRows
uint32_t version;
is.read(reinterpret_cast<char*>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
CHECK_EQ(version, 0U) << "Only version 0 is supported";
}
{
// the 2st field, rows information
......@@ -163,24 +163,22 @@ int64_t SelectedRows::AutoGrownIndex(int64_t key,
if (iter == id_to_index_.end()) {
rwlock_->UNLock();
if (!auto_grown) {
PADDLE_THROW("key %ld not found", key);
LOG(FATAL) << "key " << key << " not found";
}
rwlock_->WRLock();
auto map_size = id_to_index_.size();
auto vector_size = rows_.size();
if (map_size != vector_size) {
rwlock_->UNLock();
PADDLE_THROW(
"id_to_index_ size %lu should have the same size with rows_ %lu",
map_size,
vector_size);
LOG(FATAL) << "id_to_index_ size " << map_size
<< " should have the same size with rows_ " << vector_size;
}
auto write_iter = id_to_index_.find(key);
if (write_iter == id_to_index_.end()) {
int row_num = rows_.size();
if (row_num == value_->dims()[0]) {
rwlock_->UNLock();
PADDLE_THROW("selected rows is full, then length exceed %d", row_num);
LOG(FATAL) << "selected rows is full, then length exceed " << row_num;
}
// key logic to put a key into id_to_index_
rows_.push_back(key);
......@@ -213,16 +211,14 @@ void SelectedRows::Get(const lite::Tensor& ids,
lite::Tensor* value,
bool auto_grown,
bool is_test) {
PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized.");
CHECK(value->IsInitialized()) << "The value tensor should be initialized.";
if (ids.numel() == 0) {
VLOG(3) << "keys is empty, please check data!";
} else {
int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(value_width,
value->numel() / value->dims()[0],
"output tensor should have the same shape with table "
"except the dims[0].");
CHECK_EQ(value_width, value->numel() / value->dims()[0])
<< "output tensor should have the same shape with table "
"except the dims[0].";
for (int i = 0; i < ids.numel(); ++i) {
auto id = ids.data<int64_t>()[i];
int64_t index = AutoGrownIndex(id, auto_grown, is_test);
......
......@@ -82,7 +82,7 @@ class SelectedRows {
int64_t Index(int64_t key) const {
auto it = std::find(rows_.begin(), rows_.end(), key);
if (it == rows_.end()) {
PADDLE_THROW("id %ld not in table", key);
LOG(FATAL) << "id " << key << " not in table";
}
return static_cast<int64_t>(std::distance(rows_.begin(), it));
}
......
......@@ -22,7 +22,6 @@ limitations under the License. */
#include "lite/fluid/for_range.h"
#include "lite/fluid/transform.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/paddle_enforce.h"
#include "lite/utils/variant.h"
namespace paddle {
......@@ -66,9 +65,8 @@ inline void get_mid_dims(const lite::DDim &x_dims,
for (size_t i = 0; i < y_dims.size(); ++i) {
if (x_dims[i + axis] != y_dims[i]) {
// only support single y_dims[i] = 1 now.
PADDLE_ENFORCE_EQ(
*mid_flag, 0, "Broadcast support y_dims with single 1.");
PADDLE_ENFORCE_EQ(y_dims[i], 1, "Broadcast dimension mismatch.");
CHECK_EQ(*mid_flag, 0) << "Broadcast support y_dims with single 1.";
CHECK_EQ(y_dims[i], 1) << "Broadcast dimension mismatch.";
// m*n*k m*1*k
for (size_t j = 0; j < i; ++j) {
(*pre) *= y_dims[j];
......@@ -95,8 +93,7 @@ inline void get_mid_dims(const lite::DDim &x_dims,
}
for (size_t i = 0; i < y_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(
x_dims[i + axis], y_dims[i], "Broadcast dimension mismatch.");
CHECK_EQ(x_dims[i + axis], y_dims[i]) << "Broadcast dimension mismatch.";
(*n) *= y_dims[i];
}
......@@ -314,17 +311,16 @@ void ElementwiseComputeEx(const lite::Context<Target> &ctx,
TransformFunctor<Functor, T, Target, OutType> functor(x, y, z, ctx, func);
auto x_dims = x->dims();
auto y_dims_untrimed = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(),
y_dims_untrimed.size(),
"Rank of first input must >= rank of second input.");
CHECK_GE(x_dims.size(), y_dims_untrimed.size())
<< "Rank of first input must >= rank of second input.";
if (x_dims == y_dims_untrimed) {
functor.Run();
return;
}
axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < static_cast<int>(x_dims.size()),
"Axis should be in range [0, x_dims)");
CHECK(axis >= 0 && axis < static_cast<int>(x_dims.size()))
<< "Axis should be in range [0, x_dims)";
auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post, mid_flag = 0;
......@@ -560,9 +556,8 @@ void FusedElemwiseAndActComputeEx(const lite::Context<Target> &ctx,
lite::Tensor *out,
lite::Tensor *intermediate_out) {
if (KeepIntermediateOut) {
PADDLE_ENFORCE(intermediate_out,
"The save_intermediate_out is opened, "
"intermediate_out should not be nullptr.");
CHECK(intermediate_out) << "The save_intermediate_out is opened, "
"intermediate_out should not be nullptr.";
}
const lite::DDim &x_dim = x.dims();
......
......@@ -63,10 +63,10 @@ class LayerNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
out.ShareDataWith(*y);
out.Resize(matrix_shape);
PADDLE_ENFORCE_EQ(Mean->numel(), left);
PADDLE_ENFORCE_EQ(Var->numel(), left);
PADDLE_ENFORCE_EQ(Scale->numel(), right);
PADDLE_ENFORCE_EQ(Bias->numel(), right);
CHECK_EQ(Mean->numel(), left);
CHECK_EQ(Var->numel(), left);
CHECK_EQ(Scale->numel(), right);
CHECK_EQ(Bias->numel(), right);
auto ker = paddle::lite::jit::KernelFuncs<jit::LayerNormTuple<T>,
lite::fluid::CPUPlace>::Cache()
......
......@@ -41,8 +41,8 @@ class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto *param_out = &sgd_param.ParamOut->raw_tensor();
auto sz = param_out->numel();
PADDLE_ENFORCE_EQ(param->numel(), sz);
PADDLE_ENFORCE_EQ(grad->numel(), sz);
CHECK_EQ(param->numel(), sz);
CHECK_EQ(grad->numel(), sz);
paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
const T *lr = learning_rate->template data<T>();
......
......@@ -60,7 +60,7 @@ inline void TransCompute(const int dim,
trans6(context, in, out, axis);
break;
default:
PADDLE_THROW("Tensors with rank at most 6 are supported");
LOG(FATAL) << "Tensors with rank at most 6 are supported";
}
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file defines PADDLE_ENFORCE_xx, which helps to adapt the legacy fluid
* codes.
*/
#pragma once
#include "lite/utils/cp_logging.h"
#include "lite/utils/string.h"
#define PADDLE_ENFORCE(cond, ...) \
CHECK((cond)) << paddle::lite::string_format("" __VA_ARGS__);
#define PADDLE_ENFORCE_EQ(a, b, ...) \
CHECK_EQ((a), (b)) << paddle::lite::string_format("" __VA_ARGS__);
#define PADDLE_ENFORCE_LE(a, b, ...) \
CHECK_LE((a), (b)) << paddle::lite::string_format("" __VA_ARGS__);
#define PADDLE_ENFORCE_LT(a, b, ...) \
CHECK_LT((a), (b)) << paddle::lite::string_format("" __VA_ARGS__);
#define PADDLE_ENFORCE_GE(a, b, ...) \
CHECK_GE((a), (b)) << paddle::lite::string_format("" __VA_ARGS__);
#define PADDLE_ENFORCE_GT(a, b, ...) \
CHECK_GT((a), (b)) << paddle::lite::string_format("" __VA_ARGS__);
#ifndef PADDLE_THROW
#define PADDLE_THROW(...) printf("" __VA_ARGS__);
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册