提交 2af35002 编写于 作者: Y Yan Chunwei 提交者: GitHub

fix some enforce (#3301)

* fix some enforce

* remove compatible_type to avoid compile error

* remove shared_ptr

* fix tensor error msg
上级 95fb1617
......@@ -34,7 +34,7 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
#endif
const std::string& OperatorBase::Input(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr,
PADDLE_ENFORCE_NOT_NULL(in_out_idxs_,
"Input Output Indices could not be nullptr");
auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
......@@ -49,7 +49,7 @@ const std::string& OperatorBase::Input(const std::string& name) const {
}
std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr");
PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "IO Idx could not be nullptr");
auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(input_format.at(static_cast<size_t>(offset) + 1) <=
......@@ -62,7 +62,7 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
}
const std::string& OperatorBase::Output(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "InOut Indice could not be nullptr");
auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
name);
......@@ -76,7 +76,7 @@ const std::string& OperatorBase::Output(const std::string& name) const {
}
std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "InOut Indice could not be nullptr");
auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(output_format.at(static_cast<size_t>(offset) + 1) <=
......
......@@ -167,15 +167,15 @@ class OperatorContext {
template <typename T>
const T* Input(const size_t index) const {
auto var = InputVar(index);
PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index);
PADDLE_ENFORCE_NOT_NULL(var, "Input(%d) should not be nullptr", index);
return &var->Get<T>();
}
template <typename T>
T* Output(const size_t index) const {
auto var = OutputVar(index);
PADDLE_ENFORCE(
var != nullptr,
PADDLE_ENFORCE_NOT_NULL(
var,
"Output(%d) not be nullptr, which means variable [%s] does not "
"exist in scope",
index, op_.outputs_[index]);
......@@ -185,14 +185,14 @@ class OperatorContext {
template <typename T>
const T* Input(const std::string& name) const {
auto var = InputVar(name);
PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name);
PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name);
return &var->Get<T>();
}
template <typename T>
T* Output(const std::string& name) const {
auto var = OutputVar(name);
PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name);
PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name);
return var->GetMutable<T>();
}
......@@ -204,9 +204,9 @@ class OperatorContext {
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE(var != nullptr,
"MultiInput(%s:%s) should not be nullptr",
name, sub_name);
PADDLE_ENFORCE_NOT_NULL(
var, "MultiInput(%s:%s) should not be nullptr", name,
sub_name);
return &var->Get<T>();
});
return res;
......@@ -220,9 +220,9 @@ class OperatorContext {
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE(var != nullptr,
"MultiOutput(%s:%s) should not be nullptr",
name, sub_name);
PADDLE_ENFORCE_NOT_NULL(
var, "MultiOutput(%s:%s) should not be nullptr", name,
sub_name);
return var->GetMutable<T>();
});
return res;
......
......@@ -127,8 +127,8 @@ class Tensor {
memory::PODDeleter<T, Place>(place)),
place_(place),
size_(size) {
PADDLE_ENFORCE(ptr_ != nullptr, "Insufficient %s memory to allocation.",
is_cpu_place(place_) ? "CPU" : "GPU");
PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.",
(is_cpu_place(place_) ? "CPU" : "GPU"));
}
virtual size_t size() const { return size_; }
......
......@@ -14,15 +14,16 @@ limitations under the License. */
#pragma once
#include "paddle/memory/memcpy.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
template <typename T>
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE(holder_->size(), product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.");
}
......@@ -51,7 +52,7 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
template <typename T>
inline T* Tensor::mutable_data(platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD");
PADDLE_ENFORCE(product(dims_) > 0,
PADDLE_ENFORCE_GT(product(dims_), 0,
"Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first.");
/* some versions of boost::variant don't have operator!= */
......@@ -120,11 +121,11 @@ inline void Tensor::CopyFrom(const Tensor& src,
template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
check_memory_size<T>();
PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero.");
PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx,
PADDLE_ENFORCE_GE(begin_idx, 0, "Slice begin index is less than zero.");
PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE_LT(begin_idx, end_idx,
"Begin index must be less than end index.");
PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
int base = product(dims_) / dims_[0];
Tensor dst;
dst.holder_ = holder_;
......
......@@ -36,7 +36,8 @@ TEST(Tensor, DataAssert) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first.";
"holder_ should not be null\nTenosr holds no memory. Call "
"Tensor::mutable_data first.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
......@@ -111,7 +112,8 @@ TEST(Tensor, ShareDataWith) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first.";
"holder_ should not be null\nTenosr holds no memory. Call "
"Tensor::mutable_data first.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
......
......@@ -22,8 +22,7 @@ class AddOp : public OperatorWithKernel {
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2);
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1);
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
"Inputs of AddOp must all be set");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "Inputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Outputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
......
......@@ -20,17 +20,18 @@ namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2,
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2,
"Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(ctx.OutputSize() == 1,
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1,
"Output size of OnehotCrossEntropyOp must be one");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
"Inputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0),
"0-th input of OnehotCrossEntropyOp should be set");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(1),
"1-th input of OnehotCrossEntropyOp should be set");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0),
"Outputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
"X's dimension must be 2.");
PADDLE_ENFORCE(ctx.Output<Tensor>(0)->dims().size() == 1,
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>(0)->dims().size(), 2);
PADDLE_ENFORCE_EQ(ctx.Output<Tensor>(0)->dims().size(), 1,
"label's dimension must be 1.");
ctx.Output<Tensor>(0)->Resize({ctx.Input<Tensor>(0)->dims()[0]});
}
......
......@@ -20,13 +20,13 @@ namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1UL,
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,
"Input size of FillZerosLikeOp must be one.");
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL,
"Output size of AddOp must be one.");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr,
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0),
"Input of FillZerosLikeOp must be set.");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0),
"Output of FillZerosLikeOp must be set.");
ctx.Output<framework::Tensor>(0)->Resize(
ctx.Input<framework::Tensor>(0)->dims());
......
......@@ -20,10 +20,10 @@ namespace operators {
class MeanOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr,
"Input/Output of MeanOp must be initialized.");
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of AddOp must be one");
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "Output size of AddOp must be one");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "input should be set");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0), "output should be set");
ctx.Output<Tensor>(0)->Resize(framework::make_ddim({1}));
}
};
......
......@@ -70,15 +70,15 @@ class NetOp : public framework::OperatorBase {
*/
void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op");
PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
ops_.push_back(op);
}
void InsertOp(size_t pos, const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_,
"Cannot InsertOp when this network is sealed");
PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op");
PADDLE_ENFORCE(pos <= ops_.size(), "Out of range");
PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
PADDLE_ENFORCE_LE(pos, ops_.size(), "Out of range");
ops_.insert(ops_.begin() + pos, op);
}
......
......@@ -20,11 +20,11 @@ namespace operators {
class SGDOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, "inputs[0] mast be set");
PADDLE_ENFORCE(ctx.InputVar(1) != nullptr, "inputs[1] mast be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, "outputs[0] mast be set");
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "inputs[0] mast be set");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(1), "inputs[1] mast be set");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0), "outputs[0] mast be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
"Two input of SGD Op's dimension must be same.");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
......
......@@ -20,11 +20,11 @@ namespace operators {
class SoftmaxOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1UL,
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,
"Only one input is need for softmax");
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims().size(), 2UL,
"The input of softmax op must be matrix");
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL,
"Only one output is need for softmax");
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
}
......@@ -43,12 +43,12 @@ class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
class SoftmaxOpGrad : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 3UL,
PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
"Input of SoftmaxOpGrad should be 3, X, Y, YG");
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL,
"Output of SoftmaxOpGrad should be 1");
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
PADDLE_ENFORCE(ctx.InputVar(framework::GradVarName("Y")) != nullptr,
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
"Input(Y@GRAD) should not be null");
PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
......
......@@ -187,25 +187,16 @@ inline void throw_on_error(T e) {
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
// if two values have different data types, choose a compatible type for them.
template <typename T1, typename T2>
struct CompatibleType {
static const bool t1_to_t2 = std::is_convertible<T1, T2>::value;
typedef typename std::conditional<t1_to_t2, T2, T1>::type type;
};
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
PADDLE_ENFORCE(nullptr != (__VAL), #__VAL " should not be null\n%s", \
paddle::string::Sprintf("" __VA_ARGS__));
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
PADDLE_ENFORCE(__COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL0) \
__CMP __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL1), \
PADDLE_ENFORCE(__VAL0 __CMP __VAL1, \
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
#__VAL0, #__VAL1, std::to_string(__VAL0), \
std::to_string(__VAL1), \
paddle::string::Sprintf("" __VA_ARGS__));
#define __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL) \
typename paddle::platform::CompatibleType<decltype(__VAL0), \
decltype(__VAL1)>::type(__VAL)
} // namespace platform
} // namespace paddle
......@@ -9,8 +9,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/enforce.h"
#include <memory>
#include "gtest/gtest.h"
#include "paddle/platform/enforce.h"
TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
......@@ -196,3 +198,27 @@ TEST(ENFORCE_LT, FAIL) {
ASSERT_TRUE(in_catch);
}
TEST(ENFORCE_NOT_NULL, OK) {
int* a = new int;
PADDLE_ENFORCE_NOT_NULL(a);
delete a;
}
TEST(ENFORCE_NOT_NULL, FAIL) {
bool in_catch = false;
int* a{nullptr};
try {
PADDLE_ENFORCE_NOT_NULL(a);
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "a should not be null";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
}
ASSERT_TRUE(in_catch);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册