Unverified commit 0a46d345, authored by Tao Luo, committed by GitHub

refine some PADDLE_ENFORCE code to unify PADDLE_ASSERT_MSG (#19607)

test=develop
Parent a3a4b6e5
@@ -341,7 +341,7 @@ class ExecutionContext {
 #ifdef PADDLE_WITH_CUDA
   const inline platform::CUDADeviceContext& cuda_device_context() const {
-    PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(device_context_.GetPlace()), true);
     return *reinterpret_cast<const platform::CUDADeviceContext*>(
         &device_context_);
   }
......
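Note: the pattern throughout this commit replaces bare boolean checks, PADDLE_ENFORCE(cond), with comparisons such as PADDLE_ENFORCE_EQ(value, expected, message), so a failing check can report the offending operands rather than just "condition failed". The sketch below illustrates why the EQ form gives better diagnostics; the ENFORCE_EQ macro here is a simplified stand-in written for this note, not Paddle's actual implementation.

#include <iostream>
#include <sstream>
#include <stdexcept>

// Simplified stand-in for PADDLE_ENFORCE_EQ: on mismatch it reports both
// operands, which a bare boolean PADDLE_ENFORCE(cond) cannot do.
#define ENFORCE_EQ(a, b, msg)                                       \
  do {                                                              \
    auto lhs_ = (a);                                                \
    auto rhs_ = (b);                                                \
    if (!(lhs_ == rhs_)) {                                          \
      std::ostringstream os;                                        \
      os << msg << " Expected " #a " == " #b ", got " << lhs_       \
         << " vs " << rhs_ << ".";                                  \
      throw std::runtime_error(os.str());                           \
    }                                                               \
  } while (0)

int main() {
  bool on_gpu = false;
  try {
    // Mirrors the change above: enforce that the place flag equals true.
    ENFORCE_EQ(on_gpu, true, "cuda_device_context() requires a GPU place.");
  } catch (const std::runtime_error& e) {
    std::cerr << e.what() << "\n";  // message plus both operand values
  }
}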
@@ -18,7 +18,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/conv_cudnn_helper.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/operators/conv_op.h"
-#include "paddle/fluid/platform/assert.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
 #include "paddle/fluid/platform/cudnn_workspace_helper.h"
 #include "paddle/fluid/platform/float16.h"
......
@@ -36,10 +36,16 @@ using framework::Tensor;
 template <typename T, typename IndexT = int>
 void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
                const Tensor& index, Tensor* output) {
-  PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
+  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
   // check index of shape 1-D
-  PADDLE_ENFORCE(index.dims().size() == 1 ||
-                 (index.dims().size() == 2 && index.dims()[1] == 1));
+  if (index.dims().size() == 2) {
+    PADDLE_ENFORCE_EQ(index.dims()[1], 1,
+                      "index.dims()[1] should be 1 when index.dims().size() == "
+                      "2 in gather_op.");
+  } else {
+    PADDLE_ENFORCE_EQ(index.dims().size(), 1,
+                      "index.dims().size() should be 1 or 2 in gather_op.");
+  }
   int64_t index_size = index.dims()[0];

   auto src_dims = src.dims();
......
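The gather.h hunk above splits one combined boolean condition into an explicit if/else, so each failure names the exact constraint violated (a 2-D index with a trailing dimension other than 1, versus a wrong rank). A minimal stand-alone sketch of the same decision logic, using a hypothetical check_index_dims helper over a plain std::vector of dims rather than Paddle's DDim:

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the refactored check: an index tensor must
// be 1-D, or 2-D with a trailing dimension of exactly 1.
bool check_index_dims(const std::vector<int64_t>& dims, const char** why) {
  if (dims.size() == 2) {
    if (dims[1] != 1) {
      *why = "index.dims()[1] should be 1 when index.dims().size() == 2";
      return false;
    }
  } else if (dims.size() != 1) {
    *why = "index.dims().size() should be 1 or 2";
    return false;
  }
  return true;
}

int main() {
  const char* why = nullptr;
  // Accepted shapes: {4} and {4, 1}.
  assert(check_index_dims({4}, &why));
  assert(check_index_dims({4, 1}, &why));
  // Rejected with a branch-specific message: {4, 3}.
  if (!check_index_dims({4, 3}, &why)) std::printf("rejected: %s\n", why);
}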
@@ -19,7 +19,6 @@ https://github.com/caffe2/caffe2/blob/master/caffe2/operators/lstm_unit_op_gpu.c
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/cross_entropy_op.h"
 #include "paddle/fluid/operators/lstm_unit_op.h"
-#include "paddle/fluid/platform/assert.h"
 #include "paddle/fluid/platform/hostdevice.h"

 namespace paddle {
......
@@ -666,7 +666,11 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
                     mat_b.data<T>(), beta, mat_out->data<T>());
   } else {
-    PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
-                   dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
+    PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
+                       dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0,
+                   "dim_a.batch_size should be equal to dim_b.batch_size, or "
+                   "one of dim_a.batch_size and dim_b.batch_size should be 0. "
+                   "But got dim_a.batch_size = %d, dim_b.batch_size = %d.",
+                   dim_a.batch_size_, dim_b.batch_size_);
     this->template BatchedGEMM<T>(
         transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
         mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
......
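The MatMul hunk keeps the combined boolean condition but attaches a printf-style message whose %d placeholders are filled with the actual batch sizes at failure time. A small sketch of that variadic-message pattern; Enforce below is a hypothetical stand-in written for this note, since Paddle's macro handles the formatting internally:

#include <cstdarg>
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in: when cond is false, format the message like printf
// and abort, so the failure report carries the offending values.
void Enforce(bool cond, const char* fmt, ...) {
  if (cond) return;
  va_list args;
  va_start(args, fmt);
  std::vfprintf(stderr, fmt, args);
  va_end(args);
  std::fputc('\n', stderr);
  std::exit(EXIT_FAILURE);
}

int main() {
  int batch_a = 8, batch_b = 4;
  // Mirrors the batch-size check: sizes match, or one of them is zero.
  Enforce(batch_a == batch_b || batch_a == 0 || batch_b == 0,
          "dim_a.batch_size should be equal to dim_b.batch_size, or one of "
          "them should be 0. But got dim_a.batch_size = %d, "
          "dim_b.batch_size = %d.",
          batch_a, batch_b);
}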
@@ -90,12 +90,16 @@ template <typename T, typename IndexT = int>
 void GPUScatterAssign(const framework::ExecutionContext& context,
                       const Tensor& src, const Tensor& index, Tensor* output,
                       bool overwrite = true) {
-  // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
   const auto& ctx = context.device_context();
-  PADDLE_ENFORCE(index.dims().size() == 1 ||
-                 (index.dims().size() == 2 && index.dims()[1] == 1));
+  if (index.dims().size() == 2) {
+    PADDLE_ENFORCE_EQ(index.dims()[1], 1,
+                      "index.dims()[1] should be 1 when index.dims().size() == "
+                      "2 in scatter_op.");
+  } else {
+    PADDLE_ENFORCE_EQ(index.dims().size(), 1,
+                      "index.dims().size() should be 1 or 2 in scatter_op.");
+  }
   int index_size = index.dims()[0];

   auto src_dims = src.dims();
......
@@ -73,10 +73,16 @@ elementwise_inner_add(const framework::ExecutionContext& ctx,
 template <typename T, typename IndexT = int>
 void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
                    const Tensor& index, Tensor* output) {
-  PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
+  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
   // check index of shape 1-D
-  PADDLE_ENFORCE(index.dims().size() == 1 ||
-                 (index.dims().size() == 2 && index.dims()[1] == 1));
+  if (index.dims().size() == 2) {
+    PADDLE_ENFORCE_EQ(index.dims()[1], 1,
+                      "index.dims()[1] should be 1 when index.dims().size() == "
+                      "2 in scatter_op.");
+  } else {
+    PADDLE_ENFORCE_EQ(index.dims().size(), 1,
+                      "index.dims().size() should be 1 or 2 in scatter_op.");
+  }
   int index_size = index.dims()[0];

   auto src_dims = src.dims();
@@ -88,7 +94,7 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
   // check src shape and dst shape should match
   for (int i = 1; i < src_dims.size(); i++)
-    PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
+    PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);

   // slice size
   size_t slice_size = 1;
@@ -105,10 +111,12 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
 template <typename T, typename IndexT = int>
 void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
                       const Tensor& index, Tensor* output) {
-  PADDLE_ENFORCE(platform::is_cpu_place(ctx.device_context().GetPlace()));
+  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()),
+                    true);
   // check index of shape 1-D
-  PADDLE_ENFORCE(index.dims().size() == 1 ||
-                 (index.dims().size() == 2 && index.dims()[1] == 1));
+  PADDLE_ENFORCE(index.dims().size() == 1 ||
+                     (index.dims().size() == 2 && index.dims()[1] == 1),
+                 "");
   int index_size = index.dims()[0];

   auto src_dims = src.dims();
@@ -122,7 +130,7 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
   // check src shape and dst shape should match
   for (int i = 1; i < src_dims.size(); i++)
-    PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
+    PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);

   // slice size
   size_t slice_size = 1;
......