提交 2af35002 编写于 作者: Y Yan Chunwei 提交者: GitHub

fix some enforce (#3301)

* fix some enforce

* remove compatible_type to avoid compile error

* remove shared_ptr

* fix tensor error msg
上级 95fb1617
...@@ -34,7 +34,7 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -34,7 +34,7 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
#endif #endif
const std::string& OperatorBase::Input(const std::string& name) const { const std::string& OperatorBase::Input(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, PADDLE_ENFORCE_NOT_NULL(in_out_idxs_,
"Input Output Indices could not be nullptr"); "Input Output Indices could not be nullptr");
auto it = in_out_idxs_->find(name); auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
...@@ -49,7 +49,7 @@ const std::string& OperatorBase::Input(const std::string& name) const { ...@@ -49,7 +49,7 @@ const std::string& OperatorBase::Input(const std::string& name) const {
} }
std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr"); PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "IO Idx could not be nullptr");
auto input_format = GetAttr<std::vector<int>>("input_format"); auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = in_out_idxs_->at(name); auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(input_format.at(static_cast<size_t>(offset) + 1) <= PADDLE_ENFORCE(input_format.at(static_cast<size_t>(offset) + 1) <=
...@@ -62,7 +62,7 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { ...@@ -62,7 +62,7 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
} }
const std::string& OperatorBase::Output(const std::string& name) const { const std::string& OperatorBase::Output(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "InOut Indice could not be nullptr");
auto it = in_out_idxs_->find(name); auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
name); name);
...@@ -76,7 +76,7 @@ const std::string& OperatorBase::Output(const std::string& name) const { ...@@ -76,7 +76,7 @@ const std::string& OperatorBase::Output(const std::string& name) const {
} }
std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "InOut Indice could not be nullptr");
auto output_format = GetAttr<std::vector<int>>("output_format"); auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_->at(name); auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(output_format.at(static_cast<size_t>(offset) + 1) <= PADDLE_ENFORCE(output_format.at(static_cast<size_t>(offset) + 1) <=
......
...@@ -167,15 +167,15 @@ class OperatorContext { ...@@ -167,15 +167,15 @@ class OperatorContext {
template <typename T> template <typename T>
const T* Input(const size_t index) const { const T* Input(const size_t index) const {
auto var = InputVar(index); auto var = InputVar(index);
PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index); PADDLE_ENFORCE_NOT_NULL(var, "Input(%d) should not be nullptr", index);
return &var->Get<T>(); return &var->Get<T>();
} }
template <typename T> template <typename T>
T* Output(const size_t index) const { T* Output(const size_t index) const {
auto var = OutputVar(index); auto var = OutputVar(index);
PADDLE_ENFORCE( PADDLE_ENFORCE_NOT_NULL(
var != nullptr, var,
"Output(%d) not be nullptr, which means variable [%s] does not " "Output(%d) not be nullptr, which means variable [%s] does not "
"exist in scope", "exist in scope",
index, op_.outputs_[index]); index, op_.outputs_[index]);
...@@ -185,14 +185,14 @@ class OperatorContext { ...@@ -185,14 +185,14 @@ class OperatorContext {
template <typename T> template <typename T>
const T* Input(const std::string& name) const { const T* Input(const std::string& name) const {
auto var = InputVar(name); auto var = InputVar(name);
PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name); PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name);
return &var->Get<T>(); return &var->Get<T>();
} }
template <typename T> template <typename T>
T* Output(const std::string& name) const { T* Output(const std::string& name) const {
auto var = OutputVar(name); auto var = OutputVar(name);
PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name); PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name);
return var->GetMutable<T>(); return var->GetMutable<T>();
} }
...@@ -204,9 +204,9 @@ class OperatorContext { ...@@ -204,9 +204,9 @@ class OperatorContext {
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE(var != nullptr, PADDLE_ENFORCE_NOT_NULL(
"MultiInput(%s:%s) should not be nullptr", var, "MultiInput(%s:%s) should not be nullptr", name,
name, sub_name); sub_name);
return &var->Get<T>(); return &var->Get<T>();
}); });
return res; return res;
...@@ -220,9 +220,9 @@ class OperatorContext { ...@@ -220,9 +220,9 @@ class OperatorContext {
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE(var != nullptr, PADDLE_ENFORCE_NOT_NULL(
"MultiOutput(%s:%s) should not be nullptr", var, "MultiOutput(%s:%s) should not be nullptr", name,
name, sub_name); sub_name);
return var->GetMutable<T>(); return var->GetMutable<T>();
}); });
return res; return res;
......
...@@ -127,8 +127,8 @@ class Tensor { ...@@ -127,8 +127,8 @@ class Tensor {
memory::PODDeleter<T, Place>(place)), memory::PODDeleter<T, Place>(place)),
place_(place), place_(place),
size_(size) { size_(size) {
PADDLE_ENFORCE(ptr_ != nullptr, "Insufficient %s memory to allocation.", PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.",
is_cpu_place(place_) ? "CPU" : "GPU"); (is_cpu_place(place_) ? "CPU" : "GPU"));
} }
virtual size_t size() const { return size_; } virtual size_t size() const { return size_; }
......
...@@ -14,15 +14,16 @@ limitations under the License. */ ...@@ -14,15 +14,16 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/memory/memcpy.h" #include "paddle/memory/memcpy.h"
#include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T> template <typename T>
inline void Tensor::check_memory_size() const { inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE(holder_ != nullptr, PADDLE_ENFORCE_NOT_NULL(
"Tenosr holds no memory. Call Tensor::mutable_data first."); holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_, PADDLE_ENFORCE_GE(holder_->size(), product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data " "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory."); "first to re-allocate memory.");
} }
...@@ -51,7 +52,7 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place) { ...@@ -51,7 +52,7 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
template <typename T> template <typename T>
inline T* Tensor::mutable_data(platform::Place place) { inline T* Tensor::mutable_data(platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
PADDLE_ENFORCE(product(dims_) > 0, PADDLE_ENFORCE_GT(product(dims_), 0,
"Tensor's numel must be larger than zero to call " "Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first."); "Tensor::mutable_data. Call Tensor::set_dim first.");
/* some versions of boost::variant don't have operator!= */ /* some versions of boost::variant don't have operator!= */
...@@ -120,11 +121,11 @@ inline void Tensor::CopyFrom(const Tensor& src, ...@@ -120,11 +121,11 @@ inline void Tensor::CopyFrom(const Tensor& src,
template <typename T> template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
check_memory_size<T>(); check_memory_size<T>();
PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero."); PADDLE_ENFORCE_GE(begin_idx, 0, "Slice begin index is less than zero.");
PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound."); PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx, PADDLE_ENFORCE_LT(begin_idx, end_idx,
"Begin index must be less than end index."); "Begin index must be less than end index.");
PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1."); PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
int base = product(dims_) / dims_[0]; int base = product(dims_) / dims_[0];
Tensor dst; Tensor dst;
dst.holder_ = holder_; dst.holder_ = holder_;
......
...@@ -36,7 +36,8 @@ TEST(Tensor, DataAssert) { ...@@ -36,7 +36,8 @@ TEST(Tensor, DataAssert) {
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first."; "holder_ should not be null\nTenosr holds no memory. Call "
"Tensor::mutable_data first.";
const char* what = err.what(); const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) { for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]); ASSERT_EQ(what[i], msg[i]);
...@@ -111,7 +112,8 @@ TEST(Tensor, ShareDataWith) { ...@@ -111,7 +112,8 @@ TEST(Tensor, ShareDataWith) {
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first."; "holder_ should not be null\nTenosr holds no memory. Call "
"Tensor::mutable_data first.";
const char* what = err.what(); const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) { for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]); ASSERT_EQ(what[i], msg[i]);
......
...@@ -22,8 +22,7 @@ class AddOp : public OperatorWithKernel { ...@@ -22,8 +22,7 @@ class AddOp : public OperatorWithKernel {
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2); PADDLE_ENFORCE_EQ(ctx.InputSize(), 2);
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1); PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1);
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr, PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "Inputs of AddOp must all be set");
"Inputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Outputs of AddOp must all be set"); "Outputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(), PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
......
...@@ -20,17 +20,18 @@ namespace operators { ...@@ -20,17 +20,18 @@ namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel { class OnehotCrossEntropyOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, PADDLE_ENFORCE_EQ(ctx.InputSize(), 2,
"Input size of OnehotCrossEntropyOp must be two"); "Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(ctx.OutputSize() == 1, PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1,
"Output size of OnehotCrossEntropyOp must be one"); "Output size of OnehotCrossEntropyOp must be one");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr, PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0),
"Inputs of OnehotCrossEntropyOp must all be set"); "0-th input of OnehotCrossEntropyOp should be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(1),
"1-th input of OnehotCrossEntropyOp should be set");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0),
"Outputs of OnehotCrossEntropyOp must all be set"); "Outputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2, PADDLE_ENFORCE_EQ(ctx.Input<Tensor>(0)->dims().size(), 2);
"X's dimension must be 2."); PADDLE_ENFORCE_EQ(ctx.Output<Tensor>(0)->dims().size(), 1,
PADDLE_ENFORCE(ctx.Output<Tensor>(0)->dims().size() == 1,
"label's dimension must be 1."); "label's dimension must be 1.");
ctx.Output<Tensor>(0)->Resize({ctx.Input<Tensor>(0)->dims()[0]}); ctx.Output<Tensor>(0)->Resize({ctx.Input<Tensor>(0)->dims()[0]});
} }
......
...@@ -20,13 +20,13 @@ namespace operators { ...@@ -20,13 +20,13 @@ namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel { class FillZerosLikeOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1UL, PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,
"Input size of FillZerosLikeOp must be one."); "Input size of FillZerosLikeOp must be one.");
PADDLE_ENFORCE(ctx.OutputSize() == 1UL, PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL,
"Output size of AddOp must be one."); "Output size of AddOp must be one.");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0),
"Input of FillZerosLikeOp must be set."); "Input of FillZerosLikeOp must be set.");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0),
"Output of FillZerosLikeOp must be set."); "Output of FillZerosLikeOp must be set.");
ctx.Output<framework::Tensor>(0)->Resize( ctx.Output<framework::Tensor>(0)->Resize(
ctx.Input<framework::Tensor>(0)->dims()); ctx.Input<framework::Tensor>(0)->dims());
......
...@@ -20,10 +20,10 @@ namespace operators { ...@@ -20,10 +20,10 @@ namespace operators {
class MeanOp : public OperatorWithKernel { class MeanOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one"); PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of AddOp must be one");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr, PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "input should be set");
"Input/Output of MeanOp must be initialized."); PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0), "output should be set");
ctx.Output<Tensor>(0)->Resize(framework::make_ddim({1})); ctx.Output<Tensor>(0)->Resize(framework::make_ddim({1}));
} }
}; };
......
...@@ -70,15 +70,15 @@ class NetOp : public framework::OperatorBase { ...@@ -70,15 +70,15 @@ class NetOp : public framework::OperatorBase {
*/ */
void AddOp(const std::shared_ptr<OperatorBase>& op) { void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
ops_.push_back(op); ops_.push_back(op);
} }
void InsertOp(size_t pos, const std::shared_ptr<OperatorBase>& op) { void InsertOp(size_t pos, const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, PADDLE_ENFORCE(!add_op_done_,
"Cannot InsertOp when this network is sealed"); "Cannot InsertOp when this network is sealed");
PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
PADDLE_ENFORCE(pos <= ops_.size(), "Out of range"); PADDLE_ENFORCE_LE(pos, ops_.size(), "Out of range");
ops_.insert(ops_.begin() + pos, op); ops_.insert(ops_.begin() + pos, op);
} }
......
...@@ -20,11 +20,11 @@ namespace operators { ...@@ -20,11 +20,11 @@ namespace operators {
class SGDOp : public OperatorWithKernel { class SGDOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of SGDOp must be two"); PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of SGDOp must be one"); PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, "inputs[0] mast be set"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "inputs[0] mast be set");
PADDLE_ENFORCE(ctx.InputVar(1) != nullptr, "inputs[1] mast be set"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(1), "inputs[1] mast be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, "outputs[0] mast be set"); PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0), "outputs[0] mast be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(), PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
"Two input of SGD Op's dimension must be same."); "Two input of SGD Op's dimension must be same.");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
......
...@@ -20,11 +20,11 @@ namespace operators { ...@@ -20,11 +20,11 @@ namespace operators {
class SoftmaxOp : public OperatorWithKernel { class SoftmaxOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1UL, PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,
"Only one input is need for softmax"); "Only one input is need for softmax");
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL, PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims().size(), 2UL,
"The input of softmax op must be matrix"); "The input of softmax op must be matrix");
PADDLE_ENFORCE(ctx.OutputSize() == 1UL, PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL,
"Only one output is need for softmax"); "Only one output is need for softmax");
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims()); ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
} }
...@@ -43,12 +43,12 @@ class SoftmaxOpMaker : public OpProtoAndCheckerMaker { ...@@ -43,12 +43,12 @@ class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
class SoftmaxOpGrad : public OperatorWithKernel { class SoftmaxOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 3UL, PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
"Input of SoftmaxOpGrad should be 3, X, Y, YG"); "Input of SoftmaxOpGrad should be 3, X, Y, YG");
PADDLE_ENFORCE(ctx.OutputSize() == 1UL, PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL,
"Output of SoftmaxOpGrad should be 1"); "Output of SoftmaxOpGrad should be 1");
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx.InputVar(framework::GradVarName("Y")) != nullptr, PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
"Input(Y@GRAD) should not be null"); "Input(Y@GRAD) should not be null");
PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() == PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(), ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
......
...@@ -187,25 +187,16 @@ inline void throw_on_error(T e) { ...@@ -187,25 +187,16 @@ inline void throw_on_error(T e) {
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__) __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ #define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
// if two values have different data types, choose a compatible type for them. PADDLE_ENFORCE(nullptr != (__VAL), #__VAL " should not be null\n%s", \
template <typename T1, typename T2> paddle::string::Sprintf("" __VA_ARGS__));
struct CompatibleType {
static const bool t1_to_t2 = std::is_convertible<T1, T2>::value;
typedef typename std::conditional<t1_to_t2, T2, T1>::type type;
};
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \ #define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
PADDLE_ENFORCE(__COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL0) \ PADDLE_ENFORCE(__VAL0 __CMP __VAL1, \
__CMP __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL1), \
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \ "enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
#__VAL0, #__VAL1, std::to_string(__VAL0), \ #__VAL0, #__VAL1, std::to_string(__VAL0), \
std::to_string(__VAL1), \ std::to_string(__VAL1), \
paddle::string::Sprintf("" __VA_ARGS__)); paddle::string::Sprintf("" __VA_ARGS__));
#define __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL) \
typename paddle::platform::CompatibleType<decltype(__VAL0), \
decltype(__VAL1)>::type(__VAL)
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -9,8 +9,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,8 +9,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/platform/enforce.h" #include <memory>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/platform/enforce.h"
TEST(ENFORCE, OK) { TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
...@@ -196,3 +198,27 @@ TEST(ENFORCE_LT, FAIL) { ...@@ -196,3 +198,27 @@ TEST(ENFORCE_LT, FAIL) {
ASSERT_TRUE(in_catch); ASSERT_TRUE(in_catch);
} }
TEST(ENFORCE_NOT_NULL, OK) {
int* a = new int;
PADDLE_ENFORCE_NOT_NULL(a);
delete a;
}
TEST(ENFORCE_NOT_NULL, FAIL) {
bool in_catch = false;
int* a{nullptr};
try {
PADDLE_ENFORCE_NOT_NULL(a);
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "a should not be null";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
}
ASSERT_TRUE(in_catch);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册