未验证 提交 b465bb0d 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix adaptive_pool2d/pool3d error message. test=develop (#23658)

上级 97b09687
...@@ -30,8 +30,9 @@ template <typename T> ...@@ -30,8 +30,9 @@ template <typename T>
class CUDNNGridSampleOpKernel : public framework::OpKernel<T> { class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"It must use CUDAPlace"); platform::errors::InvalidArgument(
"It must use CUDAPlace when using CUDA Kernel"));
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
...@@ -59,10 +60,13 @@ class CUDNNGridSampleOpKernel : public framework::OpKernel<T> { ...@@ -59,10 +60,13 @@ class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>( cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize<int>(output->dims())); DataLayout::kNCHW, framework::vectorize<int>(output->dims()));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSpatialTfSamplerForward( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSpatialTfSamplerForward(
handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc, handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc,
input_data, grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc, input_data, grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc,
output_data)); output_data),
platform::errors::InvalidArgument(
"cudnnSpatialTfSamplerForward in Op(grid_sampler) failed"));
} }
}; };
...@@ -70,8 +74,9 @@ template <typename T> ...@@ -70,8 +74,9 @@ template <typename T>
class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> { class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"It must use CUDAPlace"); platform::errors::InvalidArgument(
"It must use CUDAPlace when using CUDA Kernel"));
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
...@@ -117,7 +122,9 @@ class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> { ...@@ -117,7 +122,9 @@ class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
input_data, CudnnDataType<T>::kZero(), cudnn_input_grad_desc, input_data, CudnnDataType<T>::kZero(), cudnn_input_grad_desc,
input_grad_data, CudnnDataType<T>::kOne(), cudnn_output_grad_desc, input_grad_data, CudnnDataType<T>::kOne(), cudnn_output_grad_desc,
output_grad_data, grid_data, CudnnDataType<T>::kZero(), output_grad_data, grid_data, CudnnDataType<T>::kZero(),
grid_grad_data)); grid_grad_data),
platform::errors::InvalidArgument(
"cudnnSpatialTfSamplerBackward in Op(grid_sampler) failed"));
} }
}; };
......
...@@ -28,31 +28,55 @@ class GridSampleOp : public framework::OperatorWithKernel { ...@@ -28,31 +28,55 @@ class GridSampleOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of GridSampleOp should not be null."); platform::errors::NotFound(
PADDLE_ENFORCE(ctx->HasInput("Grid"), "Input(X) of GridSampleOp should not be null."));
"Input(Grid) of GridSampleOp should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("Grid"), true,
PADDLE_ENFORCE(ctx->HasOutput("Output"), platform::errors::NotFound(
"Output(Output) of GridSampleOp should not be null."); "Input(Grid) of GridSampleOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Output"), true,
platform::errors::NotFound(
"Output(Output) of GridSampleOp should not be null."));
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid"); auto grid_dims = ctx->GetInputDim("Grid");
PADDLE_ENFORCE(x_dims.size() == 4, PADDLE_ENFORCE_EQ(x_dims.size(), 4,
"Input(X) of GridSampleOp should be 4-D Tensor."); platform::errors::InvalidArgument(
PADDLE_ENFORCE(grid_dims.size() == 4, "Input(X) of GridSampleOp should be 4-D Tensor, but "
"Input(Grid) of GridSampleOp should be 4-D Tensor."); "received X dimension size(%d)",
x_dims.size()));
PADDLE_ENFORCE_EQ(grid_dims.size(), 4,
platform::errors::InvalidArgument(
"Input(Grid) of GridSampleOp should be 4-D Tensor, "
"but received X dimension size(%d)",
grid_dims.size()));
if (ctx->IsRuntime() || grid_dims[3] > 0) { if (ctx->IsRuntime() || grid_dims[3] > 0) {
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2."); PADDLE_ENFORCE_EQ(
grid_dims[3], 2,
platform::errors::InvalidArgument(
"Input(Grid) dimension[3] should be 2, but received %d",
grid_dims[3]));
} }
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], PADDLE_ENFORCE_EQ(
"Input(X) and Input(Grid) dims[0] should be equal."); grid_dims[0], x_dims[0],
platform::errors::InvalidArgument(
"Input(X) and Input(Grid) dimension[0] should be equal, but "
"received X dimension[0](%d) != Grid dimension[0](%d)",
x_dims[0], grid_dims[0]));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
grid_dims[1], x_dims[2], grid_dims[1], x_dims[2],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal."); platform::errors::InvalidArgument(
"Input(X) dims[2] and Input(Grid) dims[1] should be equal, but "
"received X dimension[2](%d) != Grid dimension[1](%d)",
x_dims[2], grid_dims[1]));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
grid_dims[2], x_dims[3], grid_dims[2], x_dims[3],
"Input(X) dims[3] and Input(Grid) dims[2] should be equal."); platform::errors::InvalidArgument(
"Input(X) dims[3] and Input(Grid) dims[2] should be equal, but "
"received X dimension[3](%d) != Grid dimension[2](%d)",
x_dims[3], grid_dims[2]));
} }
ctx->SetOutputDim("Output", x_dims); ctx->SetOutputDim("Output", x_dims);
......
...@@ -23,30 +23,42 @@ class KLDivLossOp : public framework::OperatorWithKernel { ...@@ -23,30 +23,42 @@ class KLDivLossOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of KLDivLossOp should not be null."); platform::errors::NotFound(
PADDLE_ENFORCE(ctx->HasInput("Target"), "Input(X) of KLDivLossOp should not be null."));
"Input(Target) of KLDivLossOp should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("Target"), true,
PADDLE_ENFORCE(ctx->HasOutput("Loss"), platform::errors::NotFound(
"Output(Loss) of KLDivLossOp should not be null."); "Input(Target) of KLDivLossOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Loss"), true,
platform::errors::NotFound(
"Output(Loss) of KLDivLossOp should not be null."));
auto dim_x = ctx->GetInputDim("X"); auto dim_x = ctx->GetInputDim("X");
auto dim_target = ctx->GetInputDim("Target"); auto dim_target = ctx->GetInputDim("Target");
PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(), PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
"Input(X) rank and Input(Target) rank should be same."); platform::errors::InvalidArgument(
"Input(X) rank and Input(Target) rank should be "
"same, but received X rank(%d) != Target rank(%d)",
dim_x.size(), dim_target.size()));
for (int i = 0; i < dim_x.size(); i++) { for (int i = 0; i < dim_x.size(); i++) {
if (ctx->IsRuntime() || (dim_x[i] > 0 && dim_target[i] > 0)) { if (ctx->IsRuntime() || (dim_x[i] > 0 && dim_target[i] > 0)) {
PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i], PADDLE_ENFORCE_EQ(
"Input(X) and Input(Target) should in same shape."); dim_x[i], dim_target[i],
platform::errors::InvalidArgument(
"Input(X) and Input(Target) should in same shape. but received "
"X dimension[%d](%d) != Target dimension[%d](%d)",
i, dim_x[i], i, dim_target[i]));
} }
} }
auto reduction = ctx->Attrs().Get<std::string>("reduction"); auto reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE( auto reduction_valid = "mean" == reduction || "sum" == reduction ||
"mean" == reduction || "sum" == reduction || "batchmean" == reduction || "batchmean" == reduction || "none" == reduction;
"none" == reduction, PADDLE_ENFORCE_EQ(
"Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'."); reduction_valid, true,
platform::errors::InvalidArgument(
"Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'."));
if ("none" == reduction) { if ("none" == reduction) {
ctx->SetOutputDim("Loss", dim_x); ctx->SetOutputDim("Loss", dim_x);
...@@ -123,10 +135,15 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel { ...@@ -123,10 +135,15 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE(ctx->HasInput("Target"), "Input(Target) should not be null"); ctx->HasInput("X"), true,
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")), platform::errors::NotFound("Input(X) should not be null"));
"Input(Loss@GRAD) should not be null"); PADDLE_ENFORCE_EQ(
ctx->HasInput("Target"), true,
platform::errors::NotFound("Input(Target) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Loss")), true,
platform::errors::NotFound("Input(Loss@GRAD) should not be null"));
auto dim_x = ctx->GetInputDim("X"); auto dim_x = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), dim_x); ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
......
...@@ -26,26 +26,45 @@ class SpectralNormOp : public framework::OperatorWithKernel { ...@@ -26,26 +26,45 @@ class SpectralNormOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Weight"), PADDLE_ENFORCE_EQ(
"Input(Weight) of SpectralNormOp should not be null."); ctx->HasInput("Weight"), true,
PADDLE_ENFORCE(ctx->HasInput("U"), platform::errors::NotFound(
"Input(U) of SpectralNormOp should not be null."); "Input(Weight) of SpectralNormOp should not be null."));
PADDLE_ENFORCE(ctx->HasInput("V"), PADDLE_ENFORCE_EQ(ctx->HasInput("U"), true,
"Input(V) of SpectralNormOp should not be null."); platform::errors::NotFound(
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Input(U) of SpectralNormOp should not be null."));
"Output(Out) of SpectralNormOp should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("V"), true,
platform::errors::NotFound(
"Input(V) of SpectralNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of SpectralNormOp should not be null."));
auto dim_weight = ctx->GetInputDim("Weight"); auto dim_weight = ctx->GetInputDim("Weight");
auto rank_weight = dim_weight.size(); auto rank_weight = dim_weight.size();
PADDLE_ENFORCE(rank_weight >= 2 && rank_weight <= 5, PADDLE_ENFORCE_GE(rank_weight, 2,
"The rank of Input(Weights) can only be 2, 3," platform::errors::InvalidArgument(
"4, 5 for fc, conv1d, conv2d, conv3d layers."); "The rank of Input(Weights) should be greater equal "
"than 2, but received Weight rank(%d)",
rank_weight));
PADDLE_ENFORCE_LE(rank_weight, 5,
platform::errors::InvalidArgument(
"The rank of Input(Weights) should be less equal "
"than 5, but received Weight rank(%d)",
rank_weight));
int dim = ctx->Attrs().Get<int>("dim"); int dim = ctx->Attrs().Get<int>("dim");
int power_iters = ctx->Attrs().Get<int>("power_iters"); int power_iters = ctx->Attrs().Get<int>("power_iters");
PADDLE_ENFORCE(dim == 0 || dim == 1, "Attr(dim) can only be 0 or 1"); auto dim_valid = dim == 0 || dim == 1;
PADDLE_ENFORCE(power_iters >= 0, PADDLE_ENFORCE_EQ(
"Attr(power_iters) should be larger equal then 0"); dim_valid, true,
platform::errors::InvalidArgument(
"Attr(dim) can only be 0 or 1, but received %d", dim));
PADDLE_ENFORCE_GE(
power_iters, 0,
platform::errors::InvalidArgument(
"Attr(power_iters) should be greater equal then 0, but received %d",
power_iters));
int h = dim_weight[dim]; int h = dim_weight[dim];
int w = 1; int w = 1;
...@@ -59,15 +78,22 @@ class SpectralNormOp : public framework::OperatorWithKernel { ...@@ -59,15 +78,22 @@ class SpectralNormOp : public framework::OperatorWithKernel {
if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) { if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) {
PADDLE_ENFORCE_EQ(dim_u[0], h, PADDLE_ENFORCE_EQ(dim_u[0], h,
"Input(U) dims[0] should be equal to " platform::errors::InvalidArgument(
"Input(Weight) dims[Attr(dim)]"); "Input(U) dimension[0] should be equal to "
"Input(Weight) dimension[Attr(dim)], but received "
"U dimension[0](%d) != Weight dimension[%d](%d)",
dim_u[0], dim, h));
} }
if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) { if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dim_v[0], w, dim_v[0], w,
"Input(V) dims[0] should be equal to " platform::errors::InvalidArgument(
"the product of Input(Weight) dims except dims[Attr(dim)]"); "Input(V) dimension[0] should be equal to the product of "
"Input(Weight) dimension except dimension[Attr(dim)], but "
"received V dimension[0](%d) != product of Input(Weight) "
"dimension(%d)",
dim_v[0], w));
} }
ctx->SetOutputDim("Out", dim_weight); ctx->SetOutputDim("Out", dim_weight);
...@@ -194,11 +220,18 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel { ...@@ -194,11 +220,18 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Weight"), "Input(Weight) should not be null"); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE(ctx->HasInput("U"), "Input(U) should not be null"); ctx->HasInput("Weight"), true,
PADDLE_ENFORCE(ctx->HasInput("V"), "Input(V) should not be null"); platform::errors::NotFound("Input(Weight) should not be null"));
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE_EQ(
"Input(Out@GRAD) should not be null"); ctx->HasInput("U"), true,
platform::errors::NotFound("Input(U) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("V"), true,
platform::errors::NotFound("Input(V) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound("Input(Out@GRAD) should not be null"));
auto dim_x = ctx->GetInputDim("Weight"); auto dim_x = ctx->GetInputDim("Weight");
if (ctx->HasOutput(framework::GradVarName("Weight"))) { if (ctx->HasOutput(framework::GradVarName("Weight"))) {
ctx->SetOutputDim(framework::GradVarName("Weight"), dim_x); ctx->SetOutputDim(framework::GradVarName("Weight"), dim_x);
......
...@@ -27,26 +27,45 @@ class TemporalShiftOp : public framework::OperatorWithKernel { ...@@ -27,26 +27,45 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of TemporalShiftOp should not be null."); platform::errors::NotFound(
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, "Input(X) of TemporalShiftOp should not be null."));
"Output(Out) of TemporalShiftOp should not be null."); PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of TemporalShiftOp should not be null."));
auto dim_x = ctx->GetInputDim("X"); auto dim_x = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, PADDLE_ENFORCE_EQ(dim_x.size(), 4,
"Input(X) rank should be 4 in shape of [N*T, C, H, W]."); platform::errors::InvalidArgument(
"Input(X) rank should be 4 in shape of [N*T, C, H, "
"W], but received X rank(%d)",
dim_x.size()));
int seg_num = ctx->Attrs().Get<int>("seg_num"); int seg_num = ctx->Attrs().Get<int>("seg_num");
float shift_ratio = ctx->Attrs().Get<float>("shift_ratio"); float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0."); PADDLE_ENFORCE_GT(
PADDLE_ENFORCE_GT(shift_ratio, 0., seg_num, 0,
"Attr(shift_ratio) should be greater than 0"); platform::errors::InvalidArgument(
PADDLE_ENFORCE_LT(shift_ratio, 0.5, "Attr(seg_num) should be greater than 0, but received %d",
"Attr(shift_ratio) should be less than 0.5"); seg_num));
PADDLE_ENFORCE_GT(
shift_ratio, 0.,
platform::errors::InvalidArgument(
"Attr(shift_ratio) should be greater than 0, but received %d",
shift_ratio));
PADDLE_ENFORCE_LT(
shift_ratio, 0.5,
platform::errors::InvalidArgument(
"Attr(shift_ratio) should be less than 0.5, but received %d",
shift_ratio));
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0,
dim_x[0] % seg_num, 0, platform::errors::InvalidArgument(
"Input(X) dims[0] should be divided exactly by Attr(seg_num)."); "Input(X) dimension[0] should be divided exactly "
"by Attr(seg_num), but received X dimension[0](%d) "
"mod seg_num(%d) != 0",
dim_x[0], seg_num));
} }
ctx->SetOutputDim("Out", dim_x); ctx->SetOutputDim("Out", dim_x);
......
...@@ -90,8 +90,9 @@ template <typename T> ...@@ -90,8 +90,9 @@ template <typename T>
class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> { class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"This kernel only runs on GPU device."); platform::errors::InvalidArgument(
"This kernel only runs on GPU device."));
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out"); auto* output = ctx.Output<Tensor>("Out");
int t = ctx.Attr<int>("seg_num"); int t = ctx.Attr<int>("seg_num");
......
...@@ -21,7 +21,7 @@ from .layer_function_generator import templatedoc ...@@ -21,7 +21,7 @@ from .layer_function_generator import templatedoc
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..framework import Variable, in_dygraph_mode from ..framework import Variable, in_dygraph_mode
from .. import core from .. import core
from ..data_feeder import check_variable_and_dtype from ..data_feeder import check_variable_and_dtype, check_type
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from ..initializer import NumpyArrayInitializer, Constant from ..initializer import NumpyArrayInitializer, Constant
from .. import core from .. import core
...@@ -1580,6 +1580,10 @@ def kldiv_loss(x, target, reduction='mean', name=None): ...@@ -1580,6 +1580,10 @@ def kldiv_loss(x, target, reduction='mean', name=None):
loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='batchmean') loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='batchmean')
""" """
helper = LayerHelper('kldiv_loss', **locals()) helper = LayerHelper('kldiv_loss', **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'kldiv_loss')
check_variable_and_dtype(target, 'target', ['float32', 'float64'],
'kldiv_loss')
check_type(reduction, 'reduction', str, 'kldiv_loss')
loss = helper.create_variable_for_type_inference(dtype=x.dtype) loss = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='kldiv_loss', type='kldiv_loss',
......
...@@ -2361,6 +2361,12 @@ def adaptive_pool2d(input, ...@@ -2361,6 +2361,12 @@ def adaptive_pool2d(input,
pool_size=[3, 3], pool_size=[3, 3],
pool_type='max') pool_type='max')
""" """
check_variable_and_dtype(
input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'],
'adaptive_pool2d')
check_type(pool_type, 'pool_type', str, 'adaptive_pool2d')
check_type(pool_size, 'pool_size', (int, list, tuple), 'adaptive_pool2d')
check_type(require_index, 'require_index', bool, 'adaptive_pool2d')
if pool_type not in ["max", "avg"]: if pool_type not in ["max", "avg"]:
raise ValueError( raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", "Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
...@@ -2516,6 +2522,12 @@ def adaptive_pool3d(input, ...@@ -2516,6 +2522,12 @@ def adaptive_pool3d(input,
pool_size=[3, 3, 3], pool_size=[3, 3, 3],
pool_type='max') pool_type='max')
""" """
check_variable_and_dtype(
input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'],
'adaptive_pool3d')
check_type(pool_type, 'pool_type', str, 'adaptive_pool3d')
check_type(pool_size, 'pool_size', (int, list, tuple), 'adaptive_pool3d')
check_type(require_index, 'require_index', bool, 'adaptive_pool3d')
if pool_type not in ["max", "avg"]: if pool_type not in ["max", "avg"]:
raise ValueError( raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", "Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
...@@ -3568,6 +3580,11 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): ...@@ -3568,6 +3580,11 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
x = fluid.layers.spectral_norm(weight=weight, dim=1, power_iters=2) x = fluid.layers.spectral_norm(weight=weight, dim=1, power_iters=2)
""" """
helper = LayerHelper('spectral_norm', **locals()) helper = LayerHelper('spectral_norm', **locals())
check_variable_and_dtype(weight, 'weight', ['float32', 'float64'],
'spectral_norm')
check_type(dim, 'dim', int, 'spectral_norm')
check_type(power_iters, 'power_iters', int, 'spectral_norm')
check_type(eps, 'eps', float, 'spectral_norm')
dtype = weight.dtype dtype = weight.dtype
# create intput and parameters # create intput and parameters
...@@ -12246,6 +12263,9 @@ def grid_sampler(x, grid, name=None): ...@@ -12246,6 +12263,9 @@ def grid_sampler(x, grid, name=None):
""" """
helper = LayerHelper("grid_sampler", **locals()) helper = LayerHelper("grid_sampler", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'grid_sampler')
check_variable_and_dtype(grid, 'grid', ['float32', 'float64'],
'grid_sampler')
if not isinstance(x, Variable): if not isinstance(x, Variable):
return ValueError("The x should be a Variable") return ValueError("The x should be a Variable")
...@@ -12601,6 +12621,9 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None): ...@@ -12601,6 +12621,9 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
out = fluid.layers.temporal_shift(x=input, seg_num=2, shift_ratio=0.2) out = fluid.layers.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
""" """
helper = LayerHelper("temporal_shift", **locals()) helper = LayerHelper("temporal_shift", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'temporal_shift')
check_type(seg_num, 'seg_num', int, 'temporal_shift')
check_type(shift_ratio, 'shift_ratio', float, 'temporal_shift')
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册