提交 3647704a 编写于 作者: S seiriosPlus

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into optimize/large_scale_kv_spped

...@@ -19,13 +19,17 @@ namespace paddle { ...@@ -19,13 +19,17 @@ namespace paddle {
namespace framework { namespace framework {
extern size_t SizeOfType(proto::VarType::Type type); extern size_t SizeOfType(proto::VarType::Type type);
void Tensor::check_memory_size() const { void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(holder_, platform::errors::PreconditionNotMet(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); "Tensor holds no memory. "
"Call Tensor::mutable_data firstly."));
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
numel() * SizeOfType(type()), memory_size(), numel() * SizeOfType(type()), memory_size(),
"Tensor's dims_ is out of bound. Call Tensor::mutable_data " platform::errors::PreconditionNotMet(
"first to re-allocate memory.\n" "Tensor's dimension is out of bound."
"or maybe the required data-type mismatches the data already stored."); "Tensor's dimension must be equal or less than the size of its "
"memory."
"But received Tensor's dimension is d%, memory's size is %d.",
numel() * SizeOfType(type()), memory_size()));
} }
Tensor::Tensor(const proto::VarType::Type& dtype) : type_(dtype), offset_(0) {} Tensor::Tensor(const proto::VarType::Type& dtype) : type_(dtype), offset_(0) {}
...@@ -37,15 +41,21 @@ size_t Tensor::memory_size() const { ...@@ -37,15 +41,21 @@ size_t Tensor::memory_size() const {
void* Tensor::mutable_data(const platform::Place& place, void* Tensor::mutable_data(const platform::Place& place,
proto::VarType::Type type, size_t requested_size) { proto::VarType::Type type, size_t requested_size) {
type_ = type; type_ = type;
PADDLE_ENFORCE_GE(numel(), 0, PADDLE_ENFORCE_GE(
"When calling this method, the Tensor's numel must be " numel(), 0,
"equal or larger than zero. " platform::errors::PreconditionNotMet(
"Please check Tensor::dims, or Tensor::Resize has been " "The Tensor's element number must be equal or greater than zero. "
"called first. The Tensor's shape is [", "The Tensor's shape is [",
dims(), "] now"); dims(), "] now"));
size_t size = numel() * SizeOfType(type); size_t size = numel() * SizeOfType(type);
if (requested_size) { if (requested_size) {
PADDLE_ENFORCE_GE(requested_size, size); PADDLE_ENFORCE_GE(
requested_size, size,
platform::errors::InvalidArgument(
"The requested memory size is less than the memory size of Tensor. "
"But received requested memory size is d%, "
"memory size of Tensor is %d.",
requested_size, size));
size = requested_size; size = requested_size;
} }
/* some versions of boost::variant don't have operator!= */ /* some versions of boost::variant don't have operator!= */
...@@ -62,8 +72,8 @@ void* Tensor::mutable_data(const platform::Place& place, ...@@ -62,8 +72,8 @@ void* Tensor::mutable_data(const platform::Place& place,
void* Tensor::mutable_data(const platform::Place& place, void* Tensor::mutable_data(const platform::Place& place,
size_t requested_size) { size_t requested_size) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(this->holder_, platform::errors::PreconditionNotMet(
this->holder_, "Cannot invoke mutable data if current hold nothing."); "The tensor is not initialized."));
return mutable_data(place, type_, requested_size); return mutable_data(place, type_, requested_size);
} }
...@@ -75,12 +85,20 @@ Tensor& Tensor::ShareDataWith(const Tensor& src) { ...@@ -75,12 +85,20 @@ Tensor& Tensor::ShareDataWith(const Tensor& src) {
Tensor Tensor::Slice(int64_t begin_idx, int64_t end_idx) const { Tensor Tensor::Slice(int64_t begin_idx, int64_t end_idx) const {
check_memory_size(); check_memory_size();
PADDLE_ENFORCE_GE(begin_idx, 0, PADDLE_ENFORCE_GE(
"The start row index must be greater than 0."); begin_idx, 0,
PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound."); platform::errors::OutOfRange("The start row index must be greater than 0."
"But received the start index is d%.",
begin_idx));
PADDLE_ENFORCE_LE(
end_idx, dims_[0],
platform::errors::OutOfRange("The end row index is out of bound."));
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
begin_idx, end_idx, begin_idx, end_idx,
"The start row index must be lesser than the end row index."); platform::errors::InvalidArgument(
"The start row index must be less than the end row index."
"But received the start index = %d, the end index = %d.",
begin_idx, end_idx));
if (dims_[0] == 1) { if (dims_[0] == 1) {
return *this; return *this;
......
...@@ -131,13 +131,17 @@ class Tensor { ...@@ -131,13 +131,17 @@ class Tensor {
const platform::Place& place() const { const platform::Place& place() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::place() is called."); holder_,
platform::errors::PreconditionNotMet(
"Tensor not initialized yet when Tensor::place() is called."));
return holder_->place(); return holder_->place();
} }
proto::VarType::Type type() const { proto::VarType::Type type() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::type() is called."); holder_,
platform::errors::PreconditionNotMet(
"Tensor not initialized yet when Tensor::type() is called."));
return type_; return type_;
} }
......
...@@ -43,9 +43,13 @@ inline T* Tensor::data() { ...@@ -43,9 +43,13 @@ inline T* Tensor::data() {
check_memory_size(); check_memory_size();
bool valid = bool valid =
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType(); std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType();
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
valid, "Tensor holds the wrong type, it holds %s, but desires to be %s", valid, true,
DataTypeToString(type_), DataTypeToString(DataTypeTrait<T>::DataType())); platform::errors::InvalidArgument(
"Tensor holds the wrong type, it holds %s, but desires to be %s",
DataTypeToString(type_),
DataTypeToString(DataTypeTrait<T>::DataType())));
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
} }
...@@ -69,9 +73,12 @@ inline T* Tensor::mutable_data(const platform::Place& place, ...@@ -69,9 +73,12 @@ inline T* Tensor::mutable_data(const platform::Place& place,
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
int rank = src.dims().size(); int rank = src.dims().size();
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
rank, 2, rank, 2, platform::errors::InvalidArgument(
"'ReshapeToMatrix()' is only used for flatten high rank " "'ReshapeToMatrix()' is only used for flatten high rank "
"tensors to matrixs. Can not be used in reshaping vectors."); "tensors to matrixs. The dimensions of Tensor must be "
"greater or equal than 2. "
"But received dimensions of Tensor is %d",
rank));
if (rank == 2) { if (rank == 2) {
return src; return src;
} }
......
...@@ -41,7 +41,7 @@ TEST(Tensor, DataAssert) { ...@@ -41,7 +41,7 @@ TEST(Tensor, DataAssert) {
std::string ex_msg = err.what(); std::string ex_msg = err.what();
EXPECT_TRUE(ex_msg.find("holder_ should not be null") != std::string::npos); EXPECT_TRUE(ex_msg.find("holder_ should not be null") != std::string::npos);
EXPECT_TRUE(ex_msg.find("Tensor holds no memory. Call " EXPECT_TRUE(ex_msg.find("Tensor holds no memory. Call "
"Tensor::mutable_data first.") != "Tensor::mutable_data firstly.") !=
std::string::npos); std::string::npos);
} }
ASSERT_TRUE(caught); ASSERT_TRUE(caught);
...@@ -157,7 +157,7 @@ TEST(Tensor, ShareDataWith) { ...@@ -157,7 +157,7 @@ TEST(Tensor, ShareDataWith) {
EXPECT_TRUE(ex_msg.find("holder_ should not be null") != EXPECT_TRUE(ex_msg.find("holder_ should not be null") !=
std::string::npos); std::string::npos);
EXPECT_TRUE(ex_msg.find("Tensor holds no memory. Call " EXPECT_TRUE(ex_msg.find("Tensor holds no memory. Call "
"Tensor::mutable_data first.") != "Tensor::mutable_data firstly.") !=
std::string::npos); std::string::npos);
} }
ASSERT_TRUE(caught); ASSERT_TRUE(caught);
......
...@@ -363,9 +363,12 @@ if(WITH_MKLDNN) ...@@ -363,9 +363,12 @@ if(WITH_MKLDNN)
inference_analysis_api_test_build(${QUANT_IMG_CLASS_TEST_APP} ${QUANT_IMG_CLASS_TEST_APP_SRC}) inference_analysis_api_test_build(${QUANT_IMG_CLASS_TEST_APP} ${QUANT_IMG_CLASS_TEST_APP_SRC})
# MobileNetV1 FP32 vs. Quant INT8 # MobileNetV1 FP32 vs. Quant INT8
# The FP32 model should already be downloaded for slim Quant unit tests on Linux
set(QUANT2_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2") set(QUANT2_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2")
set(QUANT2_INT8_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2_int8") set(QUANT2_INT8_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2_int8")
if(NOT LINUX)
download_quant_data(${QUANT2_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf.tar.gz") download_quant_data(${QUANT2_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf.tar.gz")
endif(NOT LINUX)
download_quant_data(${QUANT2_INT8_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz") download_quant_data(${QUANT2_INT8_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz")
inference_analysis_api_quant_test_run(test_analyzer_quant_performance_benchmark ${QUANT_IMG_CLASS_TEST_APP} ${QUANT2_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf/float ${QUANT2_INT8_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH}) inference_analysis_api_quant_test_run(test_analyzer_quant_performance_benchmark ${QUANT_IMG_CLASS_TEST_APP} ${QUANT2_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf/float ${QUANT2_INT8_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Computes the output shape of the correlation operator.
//
// Returns {batch, output_channel, output_height, output_width} where
//   border         = (kernel_size - 1) / 2 + max_displacement
//   output_channel = ((max_displacement / stride2) * 2 + 1) ^ 2
//   output_height  = ceil((input_height + 2 * pad_size - 2 * border) / stride1)
//   output_width   = ceil((input_width  + 2 * pad_size - 2 * border) / stride1)
//
// stride1 and stride2 are used as divisors and are not validated here; the
// caller must pass non-zero values.
inline std::vector<int64_t> CorrelationOutputSize(int batch, int input_height,
                                                  int input_width, int stride1,
                                                  int stride2, int kernel_size,
                                                  int pad_size,
                                                  int max_displacement) {
  const int kernel_radius = (kernel_size - 1) / 2;
  const int border_radius = kernel_radius + max_displacement;
  const int padded_input_height = input_height + 2 * pad_size;
  const int padded_input_width = input_width + 2 * pad_size;

  // One output channel per (vertical, horizontal) displacement pair.
  const int displacement_size = (max_displacement / stride2) * 2 + 1;
  const int output_channel = displacement_size * displacement_size;

  // The ceiling division is done in double: every int is exactly
  // representable as a double, while float (24-bit significand) rounds
  // extents above 2^24 and would yield a wrong size.
  const int output_height = static_cast<int>(
      std::ceil(static_cast<double>(padded_input_height - 2 * border_radius) /
                static_cast<double>(stride1)));
  const int output_width = static_cast<int>(
      std::ceil(static_cast<double>(padded_input_width - 2 * border_radius) /
                static_cast<double>(stride1)));

  return {batch, output_channel, output_height, output_width};
}
// Declares the inputs, output, attributes and description of the
// `correlation` operator.
class CorrelationOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input1", "Input is a 4-D Tensor with shape [N, C, H, W]");
    AddInput("Input2", "Input is a 4-D Tensor with shape [N, C, H, W]");
    AddOutput("Output",
              "(Tensor) The output tensor of correlation operator. "
              "It has same data format and data type as the Input.");
    AddAttr<int>("pad_size", "pad size for input1 and input2");
    AddAttr<int>("kernel_size", "kernel size of input1 and input2");
    AddAttr<int>("max_displacement", "max displacement of input1 and input2");
    AddAttr<int>("stride1", "Input1 stride");
    AddAttr<int>("stride2", "Input2 stride");
    // The only attribute with a default value; all others must be supplied.
    AddAttr<int>("corr_type_multiply", "correlation coefficient").SetDefault(1);
    AddComment(
        R"DOC(Correlation of two feature map. Only support NCHW data format.)DOC");
  }
};
// Forward `correlation` operator: validates its inputs and attributes, infers
// the output shape and selects the kernel type.
class CorrelationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks that Input1, Input2 and Output exist, that both inputs are 4-D and
  // that both strides are positive, then sets Output's dims from Input1's
  // dims via CorrelationOutputSize.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input1"), "Input", "Input1",
                   "CorrelationOp");
    OP_INOUT_CHECK(ctx->HasInput("Input2"), "Input", "Input2",
                   "CorrelationOp");
    OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output",
                   "CorrelationOp");

    int stride1 = ctx->Attrs().Get<int>("stride1");
    int stride2 = ctx->Attrs().Get<int>("stride2");
    int max_displacement = ctx->Attrs().Get<int>("max_displacement");
    int pad_size = ctx->Attrs().Get<int>("pad_size");
    int kernel_size = ctx->Attrs().Get<int>("kernel_size");

    // Both strides are divisors in CorrelationOutputSize (and in the CUDA
    // kernels), so zero or negative values must be rejected up front.
    PADDLE_ENFORCE_GE(stride1, 1,
                      platform::errors::InvalidArgument(
                          "Attr(stride1) of CorrelationOp must be positive. "
                          "But received stride1 = %d.",
                          stride1));
    PADDLE_ENFORCE_GE(stride2, 1,
                      platform::errors::InvalidArgument(
                          "Attr(stride2) of CorrelationOp must be positive. "
                          "But received stride2 = %d.",
                          stride2));

    auto in_dims = ctx->GetInputDim("Input1");
    auto in2_dims = ctx->GetInputDim("Input2");

    PADDLE_ENFORCE_EQ(in_dims.size() == 4, true,
                      platform::errors::InvalidArgument(
                          "Input(Input1) of CorrelationOp must be 4 dims. "
                          "But received dims is %d.",
                          in_dims.size()));

    PADDLE_ENFORCE_EQ(in2_dims.size() == 4, true,
                      platform::errors::InvalidArgument(
                          "Input(Input2) of CorrelationOp must be 4 dims. "
                          "But received dims is %d.",
                          in2_dims.size()));

    // The output shape is derived from Input1 only (batch, height, width).
    std::vector<int64_t> output_shape =
        CorrelationOutputSize(in_dims[0], in_dims[2], in_dims[3], stride1,
                              stride2, kernel_size, pad_size, max_displacement);
    ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
  }

 protected:
  // The kernel data type follows Input1; Input2 must have the same type.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "Input1");
    PADDLE_ENFORCE_EQ(
        input_data_type, ctx.Input<Tensor>("Input2")->type(),
        platform::errors::InvalidArgument(
            "Input1 and Input2 should have the same datatype"));
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  // Keeps each variable's own place and layout while using the expected
  // kernel data type.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};
// Builds the description of `correlation_grad` from the forward op: it takes
// both forward inputs plus the gradient of Output, produces the gradients of
// both inputs, and reuses the forward op's attributes.
template <typename T>
class CorrelationOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("correlation_grad");
    op->SetAttrMap(this->Attrs());

    // Upstream gradient flowing into the backward op.
    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));

    // Each forward input is forwarded as-is and paired with its gradient slot.
    for (const char* input_name : {"Input1", "Input2"}) {
      op->SetInput(input_name, this->Input(input_name));
      op->SetOutput(framework::GradVarName(input_name),
                    this->InputGrad(input_name));
    }
  }
};
// Backward `correlation_grad` operator: validates its inputs and gives each
// input gradient the shape of the corresponding forward input.
class CorrelationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input1"), "Input", "Input1",
                   "CorrelationGradOp");
    OP_INOUT_CHECK(ctx->HasInput("Input2"), "Input", "Input2",
                   "CorrelationGradOp");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Output")), "Input",
                   "Output@GRAD", "CorrelationGradOp");

    auto in1_dims = ctx->GetInputDim("Input1");
    auto in2_dims = ctx->GetInputDim("Input2");
    ctx->SetOutputDim(framework::GradVarName("Input1"), in1_dims);
    ctx->SetOutputDim(framework::GradVarName("Input2"), in2_dims);
  }

 protected:
  // The kernel data type follows Input1.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input1"), ctx.GetPlace());
  }
};
// Kernel registered for the CPU place. It performs no computation: Compute()
// only enforces that the execution place is a GPU place and otherwise raises
// an Unimplemented error ("Correlation only supports GPU now.").
template <typename T>
class CorrelationKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// No output is written here; the check below is the whole body.
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::Unimplemented("Correlation only supports GPU now."));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Registers the forward operator together with its proto maker and the grad
// makers for both the OpDesc and the imperative OpBase representations.
REGISTER_OPERATOR(correlation, ops::CorrelationOp, ops::CorrelationOpMaker,
ops::CorrelationOpGradMaker<paddle::framework::OpDesc>,
ops::CorrelationOpGradMaker<paddle::imperative::OpBase>);
// Registers the backward operator.
REGISTER_OPERATOR(correlation_grad, ops::CorrelationOpGrad);
// CPU kernels (float/double) are registered for the forward op only; the
// registered kernel class just reports that the op requires a GPU place.
REGISTER_OP_CPU_KERNEL(correlation, ops::CorrelationKernel<float>,
ops::CorrelationKernel<double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
// Threads per block used by the correlation kernel launches in this file. The
// kernels also use it as the channel-loop stride and as the length of their
// per-block shared accumulation buffers.
#define THREADS_PER_BLOCK 32
// Lane mask passed to __shfl_down_sync; all 32 bits are set.
#define FULL_MASK 0xffffffff
using framework::Tensor;
using DataLayout = framework::DataLayout;
// Adds up `val` across a warp with __shfl_down_sync, using shuffle distances
// 16, 8, 4, 2, 1. Each lane returns its own accumulated value; the lowest lane
// ends up with the sum over all 32 lanes.
template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
  int distance = 16;
  while (distance > 0) {
    val += __shfl_down_sync(FULL_MASK, val, distance);
    distance /= 2;
  }
  return val;
}
// Block-level sum built on warpReduceSum: every warp reduces its own values,
// lane 0 of each warp stores the warp total in shared memory, and after a
// block-wide barrier the first warp reduces those totals.
template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
  // One slot per warp.
  static __shared__ T warp_totals[32];
  const int lane_id = threadIdx.x % warpSize;
  const int warp_id = threadIdx.x / warpSize;

  val = warpReduceSum(val);
  if (lane_id == 0) {
    warp_totals[warp_id] = val;
  }
  __syncthreads();

  // Only threads whose index is below the number of warps pick up a stored
  // total; every other thread contributes zero to the final reduction.
  if (threadIdx.x < blockDim.x / warpSize) {
    val = warp_totals[lane_id];
  } else {
    val = 0;
  }
  if (warp_id == 0) {
    val = warpReduceSum(val);
  }
  return val;
}
// Writes zero to the first `num` elements of `x`. Each thread starts at its
// global index and advances by the total number of launched threads, so any
// grid size covers the whole range.
template <typename T>
__global__ void set_zero(T *x, int num) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  while (idx < num) {
    x[idx] = static_cast<T>(0);
    idx += blockDim.x * gridDim.x;
  }
}
// Copies one (n, h, w) position of `input`, indexed as [N, channel, height,
// width], into `rinput`, indexed as [N, height + 2 * pad_size,
// width + 2 * pad_size, channel], shifting it by pad_size in both spatial
// dimensions. blockIdx.x/y/z select n/h/w; the threads of a block split the
// channels with a stride of THREADS_PER_BLOCK.
template <typename T>
__global__ void channel_first(const T *input, T *rinput, const int channel,
                              const int height, const int width,
                              const int pad_size) {
  const int batch_idx = blockIdx.x;
  const int row = blockIdx.y;
  const int col = blockIdx.z;

  // Strides of the source layout.
  const int src_image_size = channel * height * width;
  const int src_plane_size = height * width;

  // Strides of the padded destination layout.
  const int padded_width = width + 2 * pad_size;
  const int padded_height = height + 2 * pad_size;
  const int dst_image_size = channel * padded_width * padded_height;
  const int dst_row_size = channel * padded_width;

  // Offsets that do not depend on the channel.
  const int src_base = batch_idx * src_image_size + row * width + col;
  const int dst_base = batch_idx * dst_image_size +
                       (row + pad_size) * dst_row_size +
                       (col + pad_size) * channel;

  for (int ch = threadIdx.x; ch < channel; ch += THREADS_PER_BLOCK) {
    rinput[dst_base + ch] = input[src_base + ch * src_plane_size];
  }
}
// Forward correlation kernel.
//
// Each block handles one output position: blockIdx.x is the batch index,
// blockIdx.y the output row and blockIdx.z the output column. For every
// displacement (tj, ti) in [-displacement_rad, displacement_rad]^2 the block
// sums rinput1 * rinput2 over the kernel window and all channels, divides the
// sum by nelems = kernel_size * kernel_size * input_channel, and thread 0
// stores it in `output`, indexed as [N, output_channel, output_height,
// output_width].
//
// rinput1 / rinput2 are indexed as [N, padded height, padded width, channel].
// NOTE(review): the window and displacement offsets are not bounds-checked
// here; presumably pad_size is chosen so that every index stays inside the
// padded buffers - confirm against the callers.
template <typename T>
__global__ void correlation_forward(
T *output, const int output_channel, const int output_height,
const int output_width, const T *rinput1, const int input_channel,
const int input_height, const int input_width, const T *rinput2,
const int pad_size, const int kernel_size, const int max_displacement,
const int stride1, const int stride2) {
int p_input_width = input_width + 2 * pad_size;
int p_input_height = input_height + 2 * pad_size;
int kernel_rad = (kernel_size - 1) / 2;
// Number of stride2 steps on each side, and the resulting grid width.
int displacement_rad = max_displacement / stride2;
int displacement_size = 2 * displacement_rad + 1;
int n = blockIdx.x;
// Window centre in rinput1 (padded coordinates) for this output position.
int h1 = blockIdx.y * stride1 + max_displacement;
int w1 = blockIdx.z * stride1 + max_displacement;
int c = threadIdx.x;
// Strides of the padded inputs.
int p_dimchw = p_input_height * p_input_width * input_channel;
int p_dimcw = p_input_width * input_channel;
int p_dimc = input_channel;
// Strides of the output.
int t_dimchw = output_channel * output_height * output_width;
int t_dimhw = output_height * output_width;
int t_dimw = output_width;
// Normalisation factor: number of products summed per output element.
int nelems = kernel_size * kernel_size * p_dimc;
for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
// Window centre in rinput2, shifted by the current displacement.
int w2 = w1 + ti * stride2;
int h2 = h1 + tj * stride2;
T acc0 = 0;
for (int j = -kernel_rad; j <= kernel_rad; ++j) {
for (int i = -kernel_rad; i <= kernel_rad; ++i) {
// Threads of the block split the channels with a stride of blockDim.x.
for (int ch = c; ch < p_dimc; ch += blockDim.x) {
int index1 =
n * p_dimchw + (h1 + j) * p_dimcw + (w1 + i) * p_dimc + ch;
int index2 =
n * p_dimchw + (h2 + j) * p_dimcw + (w2 + i) * p_dimc + ch;
acc0 += static_cast<T>(rinput1[index1] * rinput2[index2]);
}
}
}
// Combine the per-thread partial sums: a single warp shuffle reduction
// when the block is exactly one warp, otherwise the shared-memory block
// reduction.
if (blockDim.x == warpSize) {
__syncwarp();
acc0 = warpReduceSum(acc0);
} else {
__syncthreads();
acc0 = blockReduceSum(acc0);
}
if (threadIdx.x == 0) {
// Output channel index of displacement (tj, ti), row-major over the
// displacement grid.
int tc = (tj + displacement_rad) * displacement_size +
(ti + displacement_rad);
const int t_index =
n * t_dimchw + tc * t_dimhw + blockIdx.y * t_dimw + blockIdx.z;
output[t_index] = static_cast<T>(acc0 / nelems);
}
}
}
}
// CUDA forward kernel of the `correlation` operator.
//
// Compute() repacks both inputs with channel_first into zero-filled
// temporaries of shape [N, H + 2 * pad_size, W + 2 * pad_size, C], then
// launches correlation_forward with one block per (n, output row, output
// column).
template <typename T>
class CorrelationCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                      platform::errors::InvalidArgument(
                          "Correlation only supports GPU now."));

    auto *input1 = ctx.Input<Tensor>("Input1");
    auto *input2 = ctx.Input<Tensor>("Input2");
    // Input2 is repacked and indexed with Input1's dims below, so a
    // differently shaped Input2 would be read out of bounds on the device.
    PADDLE_ENFORCE_EQ(
        input1->dims(), input2->dims(),
        platform::errors::InvalidArgument(
            "Input1 and Input2 of CorrelationOp must have the same shape. "
            "But received Input1's shape is [%s], Input2's shape is [%s].",
            input1->dims(), input2->dims()));

    // The corr_type_multiply attribute is not consumed by the device kernels,
    // so it is not read here.
    int pad_size = ctx.Attr<int>("pad_size");
    int kernel_size = ctx.Attr<int>("kernel_size");
    int stride1 = ctx.Attr<int>("stride1");
    int stride2 = ctx.Attr<int>("stride2");
    int max_displacement = ctx.Attr<int>("max_displacement");

    auto *output = ctx.Output<Tensor>("Output");
    output->mutable_data<T>(ctx.GetPlace());
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    // base on input1, NCHW
    auto in_dims = input1->dims();
    int N = in_dims[0];
    int C = in_dims[1];
    int H = in_dims[2];
    int W = in_dims[3];

    int padded_input_height = H + 2 * pad_size;
    int padded_input_width = W + 2 * pad_size;

    // Padded, channel-last copies of both inputs.
    Tensor rinput1 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
        {N, padded_input_height, padded_input_width, C}, dev_ctx);
    rinput1.mutable_data<T>(ctx.GetPlace());

    Tensor rinput2 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
        {N, padded_input_height, padded_input_width, C}, dev_ctx);
    rinput2.mutable_data<T>(ctx.GetPlace());

    // Zero-fill so the padding border (never written by channel_first) and
    // the output start from zero.
    set_zero<<<(rinput1.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
        rinput1.data<T>(), rinput1.numel());
    set_zero<<<(rinput2.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
        rinput2.data<T>(), rinput2.numel());
    set_zero<<<(output->numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
        output->data<T>(), output->numel());

    auto out_dims = output->dims();
    int OC = out_dims[1];
    int OH = out_dims[2];
    int OW = out_dims[3];

    // One block per input position (n, h, w) for the repacking step.
    dim3 blocks_grid(N, H, W);
    dim3 threads_block(THREADS_PER_BLOCK);

    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
        input1->data<T>(), rinput1.data<T>(), C, H, W, pad_size);
    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
        input2->data<T>(), rinput2.data<T>(), C, H, W, pad_size);

    // One block per output position (n, oh, ow) for the correlation itself.
    dim3 threadsPerBlock(THREADS_PER_BLOCK);
    dim3 totalBlocksCorr(N, OH, OW);

    correlation_forward<
        T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(
        output->data<T>(), OC, OH, OW, rinput1.data<T>(), C, H, W,
        rinput2.data<T>(), pad_size, kernel_size, max_displacement, stride1,
        stride2);
  }
};
// Backward kernel for the gradient with respect to the first input.
//
// `item` is the batch index. Each block handles one element of grad_input1:
// blockIdx.x / blockIdx.y select the row / column (scaled by stride1) and
// blockIdx.z the channel. The threads of the block split the output channels
// (one per displacement) with a stride of THREADS_PER_BLOCK; each accumulates
// grad_output * rinput2 into its own shared slot, and thread 0 finally sums
// the slots, divides by nelems and writes the element.
//
// rinput2 is indexed as [N, padded height, padded width, channel];
// grad_output as [N, output_channel, output_height, output_width];
// grad_input1 as [N, input_channel, input_height, input_width].
// NOTE(review): the rinput2 index (h + j2, w + i2) is not bounds-checked, and
// for stride1 > 1 the row/column blockIdx * stride1 may exceed the input
// extent depending on the launch grid - confirm against the caller.
template <typename T>
__global__ void correlation_backward_input1(
int item, T *grad_input1, const int input_channel, const int input_height,
const int input_width, const T *grad_output, const int output_channel,
const int output_height, const int output_width, const T *rinput2,
const int pad_size, const int kernel_size, const int max_displacement,
const int stride1, const int stride2) {
int n = item;
// Position in padded coordinates.
int h = blockIdx.x * stride1 + pad_size;
int w = blockIdx.y * stride1 + pad_size;
int c = blockIdx.z;
int tch_off = threadIdx.x;
int kernel_rad = (kernel_size - 1) / 2;
int displacement_rad = max_displacement / stride2;
int displacement_size = 2 * displacement_rad + 1;
// Candidate range of output columns / rows whose gradients are accumulated
// for this position.
int xmin = (w - kernel_rad - max_displacement) / stride1;
int ymin = (h - kernel_rad - max_displacement) / stride1;
int xmax = (w + kernel_rad - max_displacement) / stride1;
int ymax = (h + kernel_rad - max_displacement) / stride1;
// These early exits depend only on blockIdx and kernel arguments, so all
// threads of a block take them together; the element is then left as the
// caller initialised it.
if (xmax < 0 || ymax < 0 || xmin >= output_width || ymin >= output_height) {
return;
}
if (xmin > xmax || ymin > ymax) {
return;
}
// Clamp the range to the valid output extent.
xmin = max(0, xmin);
xmax = min(output_width - 1, xmax);
ymin = max(0, ymin);
ymax = min(output_height - 1, ymax);
// Strides of the padded input.
int p_input_width = input_width + 2 * pad_size;
int p_input_height = input_height + 2 * pad_size;
int p_dimchw = input_channel * p_input_height * p_input_width;
int p_dimcw = input_channel * p_input_width;
int p_dimc = input_channel;
// Strides of grad_output.
int t_dimchw = output_channel * output_height * output_width;
int t_dimhw = output_height * output_width;
int t_dimw = output_width;
// Strides of grad_input1.
int o_dimchw = input_channel * input_height * input_width;
int o_dimhw = input_height * input_width;
int o_dimw = input_width;
int nelems = kernel_size * kernel_size * input_channel;
// One partial sum per thread of the block.
__shared__ T prod_sum[THREADS_PER_BLOCK];
prod_sum[tch_off] = 0;
for (int tc = tch_off; tc < output_channel; tc += THREADS_PER_BLOCK) {
// Decode the output channel into its (i2, j2) displacement.
int i2 = (tc % displacement_size - displacement_rad) * stride2;
int j2 = (tc / displacement_size - displacement_rad) * stride2;
int index2 = n * p_dimchw + (h + j2) * p_dimcw + (w + i2) * p_dimc + c;
T val2 = rinput2[index2];
for (int j = ymin; j <= ymax; ++j) {
for (int i = xmin; i <= xmax; ++i) {
int t_index = n * t_dimchw + tc * t_dimhw + j * t_dimw + i;
prod_sum[tch_off] += grad_output[t_index] * val2;
}
}
}
// Make every thread's partial sum visible before thread 0 combines them.
__syncthreads();
if (tch_off == 0) {
T reduce_sum = 0;
for (int index = 0; index < THREADS_PER_BLOCK; index++) {
reduce_sum += prod_sum[index];
}
const int index1 =
n * o_dimchw + c * o_dimhw + (h - pad_size) * o_dimw + (w - pad_size);
grad_input1[index1] = static_cast<T>(reduce_sum / nelems);
}
}
// Backward kernel for the gradient with respect to the second input.
//
// `item` is the batch index. Each block handles one element of grad_input2:
// blockIdx.x / blockIdx.y select the row / column (scaled by stride1) and
// blockIdx.z the channel. The threads of the block split the output channels
// (one per displacement) with a stride of THREADS_PER_BLOCK; for each
// displacement the output range is recomputed with the displacement
// subtracted, grad_output * rinput1 is accumulated into the thread's shared
// slot, and thread 0 finally sums the slots, divides by nelems and writes the
// element.
//
// rinput1 is indexed as [N, padded height, padded width, channel];
// grad_output as [N, output_channel, output_height, output_width];
// grad_input2 as [N, input_channel, input_height, input_width].
// NOTE(review): the rinput1 index (h - j2, w - i2) is not bounds-checked, and
// for stride1 > 1 the row/column blockIdx * stride1 may exceed the input
// extent depending on the launch grid - confirm against the caller.
template <typename T>
__global__ void correlation_backward_input2(
int item, T *grad_input2, const int input_channel, const int input_height,
const int input_width, const T *grad_output, const int output_channel,
const int output_height, const int output_width, const T *rinput1,
const int pad_size, const int kernel_size, const int max_displacement,
const int stride1, const int stride2) {
int n = item;
// Position in padded coordinates.
int h = blockIdx.x * stride1 + pad_size;
int w = blockIdx.y * stride1 + pad_size;
int c = blockIdx.z;
int tch_off = threadIdx.x;
int kernel_rad = (kernel_size - 1) / 2;
int displacement_rad = max_displacement / stride2;
int displacement_size = 2 * displacement_rad + 1;
// Strides of the padded input.
int p_input_width = input_width + 2 * pad_size;
int p_input_height = input_height + 2 * pad_size;
int p_dimchw = input_channel * p_input_height * p_input_width;
int p_dimcw = input_channel * p_input_width;
int p_dimc = input_channel;
// Strides of grad_output.
int t_dimchw = output_channel * output_height * output_width;
int t_dimhw = output_height * output_width;
int t_dimw = output_width;
// Strides of grad_input2.
int o_dimchw = input_channel * input_height * input_width;
int o_dimhw = input_height * input_width;
int o_dimw = input_width;
int nelems = kernel_size * kernel_size * input_channel;
// One partial sum per thread of the block.
__shared__ T prod_sum[THREADS_PER_BLOCK];
prod_sum[tch_off] = 0;
for (int tc = tch_off; tc < output_channel; tc += THREADS_PER_BLOCK) {
// Decode the output channel into its (i2, j2) displacement.
int i2 = (tc % displacement_size - displacement_rad) * stride2;
int j2 = (tc / displacement_size - displacement_rad) * stride2;
// Candidate output range for this displacement.
int xmin = (w - kernel_rad - max_displacement - i2) / stride1;
int ymin = (h - kernel_rad - max_displacement - j2) / stride1;
int xmax = (w + kernel_rad - max_displacement - i2) / stride1;
int ymax = (h + kernel_rad - max_displacement - j2) / stride1;
// Skip displacements whose range misses the output entirely. `continue`
// (not `return`) keeps every thread on the path to the barrier below.
if (xmax < 0 || ymax < 0 || xmin >= output_width || ymin >= output_height) {
continue;
}
if (xmin > xmax || ymin > ymax) {
continue;
}
// Clamp the range to the valid output extent.
xmin = max(0, xmin);
xmax = min(output_width - 1, xmax);
ymin = max(0, ymin);
ymax = min(output_height - 1, ymax);
int index1 = n * p_dimchw + (h - j2) * p_dimcw + (w - i2) * p_dimc + c;
T val1 = rinput1[index1];
for (int j = ymin; j <= ymax; ++j) {
for (int i = xmin; i <= xmax; ++i) {
int t_index = n * t_dimchw + tc * t_dimhw + j * t_dimw + i;
prod_sum[tch_off] += grad_output[t_index] * val1;
}
}
}
// Make every thread's partial sum visible before thread 0 combines them.
__syncthreads();
if (tch_off == 0) {
T reduce_sum = 0;
for (int index = 0; index < THREADS_PER_BLOCK; index++) {
reduce_sum += prod_sum[index];
}
const int index2 =
n * o_dimchw + c * o_dimhw + (h - pad_size) * o_dimw + (w - pad_size);
grad_input2[index2] = static_cast<T>(reduce_sum / nelems);
}
}
// CUDA backward kernel of the `correlation` operator.
//
// Compute() repacks both forward inputs with channel_first into zero-filled
// temporaries of shape [N, H + 2 * pad_size, W + 2 * pad_size, C], zero-fills
// both gradient outputs, and then launches correlation_backward_input1 and
// correlation_backward_input2 once per batch item.
template <typename T>
class CorrelationCUDAGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                      platform::errors::InvalidArgument(
                          "Correlation only supports GPU now."));

    const auto *input1 = ctx.Input<Tensor>("Input1");
    const auto *input2 = ctx.Input<Tensor>("Input2");
    const auto *grad_output =
        ctx.Input<Tensor>(framework::GradVarName("Output"));
    // Input2 is repacked and indexed with Input1's dims below, so a
    // differently shaped Input2 would be read out of bounds on the device.
    PADDLE_ENFORCE_EQ(
        input1->dims(), input2->dims(),
        platform::errors::InvalidArgument(
            "Input1 and Input2 of CorrelationGradOp must have the same shape. "
            "But received Input1's shape is [%s], Input2's shape is [%s].",
            input1->dims(), input2->dims()));

    // The corr_type_multiply attribute is not consumed by the device kernels,
    // so it is not read here.
    const int pad_size = ctx.Attr<int>("pad_size");
    const int kernel_size = ctx.Attr<int>("kernel_size");
    const int stride1 = ctx.Attr<int>("stride1");
    const int stride2 = ctx.Attr<int>("stride2");
    const int max_displacement = ctx.Attr<int>("max_displacement");

    auto *grad_input1 = ctx.Output<Tensor>(framework::GradVarName("Input1"));
    grad_input1->mutable_data<T>(ctx.GetPlace());
    auto *grad_input2 = ctx.Output<Tensor>(framework::GradVarName("Input2"));
    grad_input2->mutable_data<T>(ctx.GetPlace());
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    auto in_dims = input1->dims();
    int N = in_dims[0];
    int C = in_dims[1];
    int H = in_dims[2];
    int W = in_dims[3];

    int padded_input_height = H + 2 * pad_size;
    int padded_input_width = W + 2 * pad_size;

    // Padded, channel-last copies of both forward inputs.
    Tensor rinput1 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
        {N, padded_input_height, padded_input_width, C}, dev_ctx);
    rinput1.mutable_data<T>(ctx.GetPlace());

    Tensor rinput2 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
        {N, padded_input_height, padded_input_width, C}, dev_ctx);
    rinput2.mutable_data<T>(ctx.GetPlace());

    // Zero-fill the padded temporaries and both gradients; the backward
    // kernels return early for some positions and leave them untouched.
    set_zero<<<(rinput1.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
        rinput1.data<T>(), rinput1.numel());
    set_zero<<<(rinput2.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
        rinput2.data<T>(), rinput2.numel());
    set_zero<<<(grad_input1->numel() + 512 - 1) / 512, 512, 0,
               dev_ctx.stream()>>>(grad_input1->data<T>(),
                                   grad_input1->numel());
    set_zero<<<(grad_input2->numel() + 512 - 1) / 512, 512, 0,
               dev_ctx.stream()>>>(grad_input2->data<T>(),
                                   grad_input2->numel());

    auto grad_out_dims = grad_output->dims();
    int GOC = grad_out_dims[1];
    int GOH = grad_out_dims[2];
    int GOW = grad_out_dims[3];

    // One block per input position (n, h, w) for the repacking step.
    dim3 blocks_grid(N, H, W);
    dim3 threads_block(THREADS_PER_BLOCK);

    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
        input1->data<T>(), rinput1.data<T>(), C, H, W, pad_size);
    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
        input2->data<T>(), rinput2.data<T>(), C, H, W, pad_size);

    // One block per (h, w, c) element of the gradient; the batch item is
    // passed explicitly, one launch per item.
    dim3 threadsPerBlock(THREADS_PER_BLOCK);
    dim3 totalBlocksCorr(H, W, C);

    for (int n = 0; n < N; n++) {
      correlation_backward_input1<
          T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(
          n, grad_input1->data<T>(), C, H, W, grad_output->data<T>(), GOC, GOH,
          GOW, rinput2.data<T>(), pad_size, kernel_size, max_displacement,
          stride1, stride2);
    }

    for (int n = 0; n < N; n++) {
      correlation_backward_input2<
          T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(
          n, grad_input2->data<T>(), C, H, W, grad_output->data<T>(), GOC, GOH,
          GOW, rinput1.data<T>(), pad_size, kernel_size, max_displacement,
          stride1, stride2);
    }
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// CUDA kernels (float/double) for the forward `correlation` operator.
REGISTER_OP_CUDA_KERNEL(correlation, ops::CorrelationCUDAKernel<float>,
ops::CorrelationCUDAKernel<double>);
// CUDA kernels (float/double) for the backward `correlation_grad` operator.
REGISTER_OP_CUDA_KERNEL(correlation_grad, ops::CorrelationCUDAGradKernel<float>,
ops::CorrelationCUDAGradKernel<double>);
...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and ...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License.*/ limitations under the License.*/
#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h" #include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -54,11 +55,14 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel { ...@@ -54,11 +55,14 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel {
score_dim[1])); score_dim[1]));
} }
context->SetOutputDim("FpnRois", {post_nms_topN, 4}); context->SetOutputDim("FpnRois", {post_nms_topN, 4});
if (context->HasOutput("RoisNum")) {
context->SetOutputDim("RoisNum", {-1});
}
if (!context->IsRuntime()) { // Runtime LoD infershape will be computed if (!context->IsRuntime()) { // Runtime LoD infershape will be computed
// in Kernel. // in Kernel.
context->ShareLoD("MultiLevelRois", "FpnRois"); context->ShareLoD("MultiLevelRois", "FpnRois");
} }
if (context->IsRuntime()) { if (context->IsRuntime() && !context->HasInputs("MultiLevelRoIsNum")) {
std::vector<framework::InferShapeVarPtr> roi_inputs = std::vector<framework::InferShapeVarPtr> roi_inputs =
context->GetInputVarPtrs("MultiLevelRois"); context->GetInputVarPtrs("MultiLevelRois");
std::vector<framework::InferShapeVarPtr> score_inputs = std::vector<framework::InferShapeVarPtr> score_inputs =
...@@ -99,7 +103,16 @@ class CollectFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -99,7 +103,16 @@ class CollectFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor) Multiple score LoDTensors from each level in shape" "(LoDTensor) Multiple score LoDTensors from each level in shape"
" (N, 1), N is the number of RoIs.") " (N, 1), N is the number of RoIs.")
.AsDuplicable(); .AsDuplicable();
AddInput(
"MultiLevelRoIsNum",
"(List of Tensor) The RoIs' number of each image on multiple levels."
"The number on each level has the shape of (N), N is the number of "
"images.")
.AsDuplicable()
.AsDispensable();
AddOutput("FpnRois", "(LoDTensor) All selected RoIs with highest scores"); AddOutput("FpnRois", "(LoDTensor) All selected RoIs with highest scores");
AddOutput("RoisNum", "(Tensor), Number of RoIs in each images.")
.AsDispensable();
AddAttr<int>("post_nms_topN", AddAttr<int>("post_nms_topN",
"Select post_nms_topN RoIs from" "Select post_nms_topN RoIs from"
" all images and all fpn layers"); " all images and all fpn layers");
...@@ -123,3 +136,14 @@ REGISTER_OPERATOR( ...@@ -123,3 +136,14 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(collect_fpn_proposals, REGISTER_OP_CPU_KERNEL(collect_fpn_proposals,
ops::CollectFpnProposalsOpKernel<float>, ops::CollectFpnProposalsOpKernel<float>,
ops::CollectFpnProposalsOpKernel<double>); ops::CollectFpnProposalsOpKernel<double>);
// Op-version checkpoint for collect_fpn_proposals: records that this revision
// adds the input "MultiLevelRoIsNum" and the output "RoisNum". The names given
// to NewInput/NewOutput match the ones declared in the op maker.
REGISTER_OP_VERSION(collect_fpn_proposals)
.AddCheckpoint(
R"ROC(
Upgrade collect_fpn_proposals add a new input
[MultiLevelRoIsNum] and add a new output [RoisNum].)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput("MultiLevelRoIsNum",
"The RoIs' number of each image on multiple levels."
"The number on each level has the shape of (N), "
"N is the number of images.")
.NewOutput("RoisNum", "The number of RoIs in each image."));
...@@ -80,16 +80,29 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -80,16 +80,29 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
int lod_size; int lod_size;
auto place = BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()); auto place = BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
auto multi_rois_num = ctx.MultiInput<Tensor>("MultiLevelRoIsNum");
for (size_t i = 0; i < roi_ins.size(); ++i) { for (size_t i = 0; i < roi_ins.size(); ++i) {
auto roi_in = roi_ins[i]; auto roi_in = roi_ins[i];
auto score_in = score_ins[i]; auto score_in = score_ins[i];
auto roi_lod = roi_in->lod().back(); if (multi_rois_num.size() > 0) {
lod_size = roi_lod.size() - 1; framework::Tensor temp;
TensorCopySync(*multi_rois_num[i], platform::CPUPlace(), &temp);
const int* length_in = temp.data<int>();
lod_size = multi_rois_num[i]->numel();
for (size_t n = 0; n < lod_size; ++n) { for (size_t n = 0; n < lod_size; ++n) {
for (size_t j = roi_lod[n]; j < roi_lod[n + 1]; ++j) { for (size_t j = 0; j < length_in[n]; ++j) {
roi_batch_id_data[index++] = n; roi_batch_id_data[index++] = n;
} }
} }
} else {
auto length_in = roi_in->lod().back();
lod_size = length_in.size() - 1;
for (size_t n = 0; n < lod_size; ++n) {
for (size_t j = length_in[n]; j < length_in[n + 1]; ++j) {
roi_batch_id_data[index++] = n;
}
}
}
memory::Copy(place, concat_rois_data + roi_offset, place, memory::Copy(place, concat_rois_data + roi_offset, place,
roi_in->data<T>(), roi_in->numel() * sizeof(T), roi_in->data<T>(), roi_in->numel() * sizeof(T),
...@@ -190,6 +203,13 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -190,6 +203,13 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
offset.emplace_back(offset.back() + length_lod_cpu[i]); offset.emplace_back(offset.back() + length_lod_cpu[i]);
} }
if (ctx.HasOutput("RoisNum")) {
auto* rois_num = ctx.Output<Tensor>("RoisNum");
int* rois_num_data = rois_num->mutable_data<int>({lod_size}, place);
memory::Copy(place, rois_num_data, place, length_lod_data,
lod_size * sizeof(int), dev_ctx.stream());
}
framework::LoD lod; framework::LoD lod;
lod.emplace_back(offset); lod.emplace_back(offset);
fpn_rois->set_lod(lod); fpn_rois->set_lod(lod);
......
...@@ -17,6 +17,7 @@ limitations under the License.*/ ...@@ -17,6 +17,7 @@ limitations under the License.*/
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <numeric>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -65,6 +66,8 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -65,6 +66,8 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
auto multi_layer_scores = auto multi_layer_scores =
context.MultiInput<paddle::framework::LoDTensor>("MultiLevelScores"); context.MultiInput<paddle::framework::LoDTensor>("MultiLevelScores");
auto multi_rois_num = context.MultiInput<Tensor>("MultiLevelRoIsNum");
int num_size = multi_rois_num.size();
auto* fpn_rois = context.Output<paddle::framework::LoDTensor>("FpnRois"); auto* fpn_rois = context.Output<paddle::framework::LoDTensor>("FpnRois");
...@@ -88,11 +91,21 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -88,11 +91,21 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
const int num_fpn_level = multi_layer_rois.size(); const int num_fpn_level = multi_layer_rois.size();
std::vector<int> integral_of_all_rois(num_fpn_level + 1, 0); std::vector<int> integral_of_all_rois(num_fpn_level + 1, 0);
for (int i = 0; i < num_fpn_level; ++i) { for (int i = 0; i < num_fpn_level; ++i) {
int all_rois = 0;
if (num_size == 0) {
auto cur_rois_lod = multi_layer_rois[i]->lod().back(); auto cur_rois_lod = multi_layer_rois[i]->lod().back();
integral_of_all_rois[i + 1] = all_rois = cur_rois_lod[cur_rois_lod.size() - 1];
integral_of_all_rois[i] + cur_rois_lod[cur_rois_lod.size() - 1]; } else {
const int* cur_rois_num = multi_rois_num[i]->data<int>();
all_rois = std::accumulate(
cur_rois_num, cur_rois_num + multi_rois_num[i]->numel(), 0);
}
integral_of_all_rois[i + 1] = integral_of_all_rois[i] + all_rois;
} }
const int batch_size = (num_size == 0)
? multi_layer_rois[0]->lod().back().size() - 1
: multi_rois_num[0]->numel();
// concatenate all fpn rois scores into a list // concatenate all fpn rois scores into a list
// create a vector to store all scores // create a vector to store all scores
std::vector<ScoreWithID<T>> scores_of_all_rois( std::vector<ScoreWithID<T>> scores_of_all_rois(
...@@ -100,12 +113,21 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -100,12 +113,21 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
for (int i = 0; i < num_fpn_level; ++i) { for (int i = 0; i < num_fpn_level; ++i) {
const T* cur_level_scores = multi_layer_scores[i]->data<T>(); const T* cur_level_scores = multi_layer_scores[i]->data<T>();
int cur_level_num = integral_of_all_rois[i + 1] - integral_of_all_rois[i]; int cur_level_num = integral_of_all_rois[i + 1] - integral_of_all_rois[i];
auto cur_scores_lod = multi_layer_scores[i]->lod().back();
int cur_batch_id = 0; int cur_batch_id = 0;
int pre_num = 0;
for (int j = 0; j < cur_level_num; ++j) { for (int j = 0; j < cur_level_num; ++j) {
if (num_size == 0) {
auto cur_scores_lod = multi_layer_scores[i]->lod().back();
if (static_cast<size_t>(j) >= cur_scores_lod[cur_batch_id + 1]) { if (static_cast<size_t>(j) >= cur_scores_lod[cur_batch_id + 1]) {
cur_batch_id++; cur_batch_id++;
} }
} else {
const int* rois_num_data = multi_rois_num[i]->data<int>();
if (j >= pre_num + rois_num_data[cur_batch_id]) {
pre_num += rois_num_data[cur_batch_id];
cur_batch_id++;
}
}
int cur_index = j + integral_of_all_rois[i]; int cur_index = j + integral_of_all_rois[i];
scores_of_all_rois[cur_index].score = cur_level_scores[j]; scores_of_all_rois[cur_index].score = cur_level_scores[j];
scores_of_all_rois[cur_index].index = j; scores_of_all_rois[cur_index].index = j;
...@@ -134,6 +156,9 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -134,6 +156,9 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
T* fpn_rois_data = fpn_rois->data<T>(); T* fpn_rois_data = fpn_rois->data<T>();
std::vector<size_t> lod0(1, 0); std::vector<size_t> lod0(1, 0);
int cur_batch_id = 0; int cur_batch_id = 0;
std::vector<int64_t> num_per_batch;
int pre_idx = 0;
int cur_num = 0;
for (int i = 0; i < post_nms_topN; ++i) { for (int i = 0; i < post_nms_topN; ++i) {
int cur_fpn_level = scores_of_all_rois[i].level; int cur_fpn_level = scores_of_all_rois[i].level;
int cur_level_index = scores_of_all_rois[i].index; int cur_level_index = scores_of_all_rois[i].index;
...@@ -144,6 +169,18 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -144,6 +169,18 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
if (scores_of_all_rois[i].batch_id != cur_batch_id) { if (scores_of_all_rois[i].batch_id != cur_batch_id) {
cur_batch_id = scores_of_all_rois[i].batch_id; cur_batch_id = scores_of_all_rois[i].batch_id;
lod0.emplace_back(i); lod0.emplace_back(i);
cur_num = i - pre_idx;
pre_idx = i;
num_per_batch.emplace_back(cur_num);
}
}
num_per_batch.emplace_back(post_nms_topN - pre_idx);
if (context.HasOutput("RoisNum")) {
auto* rois_num = context.Output<Tensor>("RoisNum");
int* rois_num_data =
rois_num->mutable_data<int>({batch_size}, context.GetPlace());
for (int i = 0; i < batch_size; i++) {
rois_num_data[i] = num_per_batch[i];
} }
} }
lod0.emplace_back(post_nms_topN); lod0.emplace_back(post_nms_topN);
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h" #include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -48,6 +49,14 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel { ...@@ -48,6 +49,14 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputsDim("MultiFpnRois", outs_dims); ctx->SetOutputsDim("MultiFpnRois", outs_dims);
ctx->SetOutputDim("RestoreIndex", {-1, 1}); ctx->SetOutputDim("RestoreIndex", {-1, 1});
if (ctx->HasOutputs("MultiLevelRoIsNum")) {
std::vector<framework::DDim> outs_num_dims;
for (size_t i = 0; i < num_out_rois; ++i) {
outs_num_dims.push_back({-1});
}
ctx->SetOutputsDim("MultiLevelRoIsNum", outs_num_dims);
}
if (!ctx->IsRuntime()) { if (!ctx->IsRuntime()) {
for (size_t i = 0; i < num_out_rois; ++i) { for (size_t i = 0; i < num_out_rois; ++i) {
ctx->SetLoDLevel("MultiFpnRois", ctx->GetLoDLevel("FpnRois"), i); ctx->SetLoDLevel("MultiFpnRois", ctx->GetLoDLevel("FpnRois"), i);
...@@ -66,12 +75,22 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel { ...@@ -66,12 +75,22 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
class DistributeFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker { class DistributeFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("FpnRois", "(LoDTensor) The rois at all levels in shape (-1, 4)"); AddInput("FpnRois", "(LoDTensor) The RoIs at all levels in shape (-1, 4)");
AddInput("RoisNum",
"(Tensor) The number of RoIs in shape (B),"
"B is the number of images")
.AsDispensable();
AddOutput("MultiFpnRois", "(LoDTensor) Output with distribute operator") AddOutput("MultiFpnRois", "(LoDTensor) Output with distribute operator")
.AsDuplicable(); .AsDuplicable();
AddOutput("RestoreIndex", AddOutput("RestoreIndex",
"(Tensor) An array of positive number which is " "(Tensor) An array of positive number which is "
"used to restore the order of FpnRois"); "used to restore the order of FpnRois");
AddOutput("MultiLevelRoIsNum",
"(List of Tensor) The RoIs' number of each image on multiple "
"levels. The number on each level has the shape of (B),"
"B is the number of images.")
.AsDuplicable()
.AsDispensable();
AddAttr<int>("min_level", AddAttr<int>("min_level",
"The lowest level of FPN layer where the" "The lowest level of FPN layer where the"
" proposals come from"); " proposals come from");
...@@ -105,3 +124,14 @@ REGISTER_OPERATOR( ...@@ -105,3 +124,14 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(distribute_fpn_proposals, REGISTER_OP_CPU_KERNEL(distribute_fpn_proposals,
ops::DistributeFpnProposalsOpKernel<float>, ops::DistributeFpnProposalsOpKernel<float>,
ops::DistributeFpnProposalsOpKernel<double>); ops::DistributeFpnProposalsOpKernel<double>);
// Op-version checkpoint for distribute_fpn_proposals: records that this
// revision adds the input "RoisNum" and the output "MultiLevelRoIsNum".
//
// The names passed to NewInput/NewOutput must be spelled exactly as they are
// declared in DistributeFpnProposalsOpMaker ("RoisNum" and
// "MultiLevelRoIsNum"); the previous spellings ("RoIsNum" and
// "MultiLevelRoisNum") referred to variables that the op does not declare.
REGISTER_OP_VERSION(distribute_fpn_proposals)
    .AddCheckpoint(
        R"ROC(
Upgrade distribute_fpn_proposals add a new input
[RoisNum] and add a new output [MultiLevelRoIsNum].)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .NewInput("RoisNum", "The number of RoIs in each image.")
            .NewOutput("MultiLevelRoIsNum",
                       "The RoIs' number of each image on multiple "
                       "levels. The number on each level has the shape of (B),"
                       "B is the number of images."));
...@@ -76,12 +76,20 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -76,12 +76,20 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
int num_level = max_level - min_level + 1; int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty // check that the fpn_rois is not empty
if (!ctx.HasInput("RoisNum")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
fpn_rois->lod().size(), 1UL, fpn_rois->lod().size(), 1UL,
platform::errors::InvalidArgument("DistributeFpnProposalsOp needs LoD" platform::errors::InvalidArgument("DistributeFpnProposalsOp needs LoD"
"with one level")); "with one level"));
}
auto fpn_rois_lod = fpn_rois->lod().back(); std::vector<size_t> fpn_rois_lod;
if (ctx.HasInput("RoisNum")) {
auto* rois_num = ctx.Input<Tensor>("RoisNum");
fpn_rois_lod = GetLodFromRoisNum(rois_num);
} else {
fpn_rois_lod = fpn_rois->lod().back();
}
int lod_size = fpn_rois_lod.size() - 1; int lod_size = fpn_rois_lod.size() - 1;
int roi_num = fpn_rois_lod[lod_size]; int roi_num = fpn_rois_lod[lod_size];
...@@ -154,6 +162,8 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -154,6 +162,8 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
restore_idx_data, roi_num); restore_idx_data, roi_num);
int start = 0; int start = 0;
auto multi_rois_num = ctx.MultiOutput<Tensor>("MultiLevelRoIsNum");
for (int i = 0; i < num_level; ++i) { for (int i = 0; i < num_level; ++i) {
Tensor sub_lod = sub_lod_list.Slice(i, i + 1); Tensor sub_lod = sub_lod_list.Slice(i, i + 1);
int* sub_lod_data = sub_lod.data<int>(); int* sub_lod_data = sub_lod.data<int>();
...@@ -180,6 +190,11 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -180,6 +190,11 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim}, multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
dev_ctx.GetPlace()); dev_ctx.GetPlace());
} }
if (multi_rois_num.size() > 0) {
Tensor* rois_num_t = multi_rois_num[i];
TensorCopySync(sub_lod, dev_ctx.GetPlace(), rois_num_t);
rois_num_t->Resize({lod_size});
}
framework::LoD lod; framework::LoD lod;
lod.emplace_back(offset); lod.emplace_back(offset);
multi_fpn_rois[i]->set_lod(lod); multi_fpn_rois[i]->set_lod(lod);
......
...@@ -28,6 +28,21 @@ namespace operators { ...@@ -28,6 +28,21 @@ namespace operators {
const int kBoxDim = 4; const int kBoxDim = 4;
inline std::vector<size_t> GetLodFromRoisNum(const Tensor* rois_num) {
std::vector<size_t> rois_lod;
auto* rois_num_data = rois_num->data<int>();
Tensor cpu_tensor;
if (platform::is_gpu_place(rois_num->place())) {
TensorCopySync(*rois_num, platform::CPUPlace(), &cpu_tensor);
rois_num_data = cpu_tensor.data<int>();
}
rois_lod.push_back(static_cast<size_t>(0));
for (int i = 0; i < rois_num->numel(); ++i) {
rois_lod.push_back(rois_lod.back() + static_cast<size_t>(rois_num_data[i]));
}
return rois_lod;
}
template <typename T> template <typename T>
static inline T BBoxArea(const T* box, bool normalized) { static inline T BBoxArea(const T* box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) { if (box[2] < box[0] || box[3] < box[1]) {
...@@ -65,13 +80,22 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -65,13 +80,22 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
const int num_level = max_level - min_level + 1; const int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty // check that the fpn_rois is not empty
PADDLE_ENFORCE_EQ( if (!context.HasInput("RoisNum")) {
fpn_rois->lod().size(), 1UL, PADDLE_ENFORCE_EQ(fpn_rois->lod().size(), 1UL,
platform::errors::InvalidArgument("DistributeFpnProposalsOp needs LoD " platform::errors::InvalidArgument(
"DistributeFpnProposalsOp needs LoD "
"with one level.")); "with one level."));
}
auto fpn_rois_lod = fpn_rois->lod().back(); std::vector<size_t> fpn_rois_lod;
int fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1]; int fpn_rois_num;
if (context.HasInput("RoisNum")) {
auto* rois_num = context.Input<Tensor>("RoisNum");
fpn_rois_lod = GetLodFromRoisNum(rois_num);
} else {
fpn_rois_lod = fpn_rois->lod().back();
}
fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1];
std::vector<int> target_level; std::vector<int> target_level;
// std::vector<int> target_level(fpn_rois_num, -1); // std::vector<int> target_level(fpn_rois_num, -1);
// record the number of rois in each level // record the number of rois in each level
...@@ -136,6 +160,18 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -136,6 +160,18 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
for (int i = 0; i < fpn_rois_num; ++i) { for (int i = 0; i < fpn_rois_num; ++i) {
restore_index_data[restore_index_inter[i]] = i; restore_index_data[restore_index_inter[i]] = i;
} }
auto multi_rois_num = context.MultiOutput<Tensor>("MultiLevelRoIsNum");
if (multi_rois_num.size() > 0) {
int batch_size = fpn_rois_lod.size() - 1;
for (int i = 0; i < num_level; ++i) {
int* rois_num_data = multi_rois_num[i]->mutable_data<int>(
{batch_size}, context.GetPlace());
for (int j = 0; j < batch_size; ++j) {
rois_num_data[j] = static_cast<int>(multi_fpn_rois_lod0[i][j + 1] -
multi_fpn_rois_lod0[i][j]);
}
}
}
// merge lod information into LoDTensor // merge lod information into LoDTensor
for (int i = 0; i < num_level; ++i) { for (int i = 0; i < num_level; ++i) {
framework::LoD lod; framework::LoD lod;
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -61,6 +62,10 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { ...@@ -61,6 +62,10 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("RpnRois", {-1, 4}); ctx->SetOutputDim("RpnRois", {-1, 4});
ctx->SetOutputDim("RpnRoiProbs", {-1, 1}); ctx->SetOutputDim("RpnRoiProbs", {-1, 1});
if (!ctx->IsRuntime()) {
ctx->SetLoDLevel("RpnRois", std::max(ctx->GetLoDLevel("Scores"), 1));
ctx->SetLoDLevel("RpnRoiProbs", std::max(ctx->GetLoDLevel("Scores"), 1));
}
} }
protected: protected:
...@@ -347,7 +352,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -347,7 +352,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
lod0.push_back(0); lod0.push_back(0);
anchors.Resize({anchors.numel() / 4, 4}); anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4}); variances.Resize({variances.numel() / 4, 4});
std::vector<int64_t> tmp_lod; std::vector<int> tmp_num;
int64_t num_proposals = 0; int64_t num_proposals = 0;
for (int64_t i = 0; i < num; ++i) { for (int64_t i = 0; i < num; ++i) {
...@@ -369,16 +374,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -369,16 +374,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
AppendProposals(rpn_roi_probs, num_proposals, scores); AppendProposals(rpn_roi_probs, num_proposals, scores);
num_proposals += proposals.dims()[0]; num_proposals += proposals.dims()[0];
lod0.push_back(num_proposals); lod0.push_back(num_proposals);
tmp_lod.push_back(num_proposals); tmp_num.push_back(proposals.dims()[0]);
} }
if (context.HasOutput("RpnRoisLod")) { if (context.HasOutput("RpnRoisNum")) {
auto *rpn_rois_lod = context.Output<Tensor>("RpnRoisLod"); auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
rpn_rois_lod->mutable_data<int64_t>({num}, context.GetPlace()); rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
int64_t *lod_data = rpn_rois_lod->data<int64_t>(); int *num_data = rpn_rois_num->data<int>();
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
lod_data[i] = tmp_lod[i]; num_data[i] = tmp_num[i];
} }
rpn_rois_lod->Resize({num}); rpn_rois_num->Resize({num});
} }
rpn_rois->set_lod(lod); rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod); rpn_roi_probs->set_lod(lod);
...@@ -433,6 +438,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -433,6 +438,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
Tensor keep; Tensor keep;
FilterBoxes<T>(ctx, &proposals, min_size, im_info_slice, &keep); FilterBoxes<T>(ctx, &proposals, min_size, im_info_slice, &keep);
// Handle the case when there is no keep index left
if (keep.numel() == 0) {
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
bbox_sel.mutable_data<T>({1, 4}, ctx.GetPlace());
set_zero(ctx, &bbox_sel, static_cast<T>(0));
Tensor scores_filter;
scores_filter.mutable_data<T>({1, 1}, ctx.GetPlace());
set_zero(ctx, &scores_filter, static_cast<T>(0));
return std::make_pair(bbox_sel, scores_filter);
}
Tensor scores_filter; Tensor scores_filter;
bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace()); bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
...@@ -481,7 +496,8 @@ class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -481,7 +496,8 @@ class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor), Output proposals with shape (rois_num, 4)."); "(LoDTensor), Output proposals with shape (rois_num, 4).");
AddOutput("RpnRoiProbs", AddOutput("RpnRoiProbs",
"(LoDTensor) Scores of proposals with shape (rois_num, 1)."); "(LoDTensor) Scores of proposals with shape (rois_num, 1).");
AddOutput("RpnRoisLod", "(Tensor), rpn rois's lod info").AsDispensable(); AddOutput("RpnRoisNum", "(Tensor), The number of Rpn RoIs in each image")
.AsDispensable();
AddAttr<int>("pre_nms_topN", AddAttr<int>("pre_nms_topN",
"Number of top scoring RPN proposals to keep before " "Number of top scoring RPN proposals to keep before "
"applying NMS."); "applying NMS.");
...@@ -515,3 +531,11 @@ REGISTER_OPERATOR( ...@@ -515,3 +531,11 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(generate_proposals, ops::GenerateProposalsKernel<float>, REGISTER_OP_CPU_KERNEL(generate_proposals, ops::GenerateProposalsKernel<float>,
ops::GenerateProposalsKernel<double>); ops::GenerateProposalsKernel<double>);
// Op-version checkpoint for generate_proposals: records that this revision
// adds the output "RpnRoisNum" (declared as dispensable in the op maker).
REGISTER_OP_VERSION(generate_proposals)
.AddCheckpoint(
R"ROC(
Upgrade generate_proposals add a new output [RpnRoisNum])ROC",
paddle::framework::compatible::OpVersionDesc().NewOutput(
"RpnRoisNum",
"The number of Rpn RoIs in each image. RpnRoisNum is "
"dispensable."));
...@@ -330,6 +330,15 @@ static std::pair<Tensor, Tensor> ProposalForOneImage( ...@@ -330,6 +330,15 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
keep_index.Resize({keep_num}); keep_index.Resize({keep_num});
Tensor scores_filter, proposals_filter; Tensor scores_filter, proposals_filter;
// Handle the case when there is no keep index left
if (keep_num == 0) {
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
proposals_filter.mutable_data<T>({1, 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({1, 1}, ctx.GetPlace());
set_zero(ctx, &proposals_filter, static_cast<T>(0));
set_zero(ctx, &scores_filter, static_cast<T>(0));
return std::make_pair(proposals_filter, scores_filter);
}
proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace()); proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace()); scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace());
GPUGather<T>(ctx, proposals, keep_index, &proposals_filter); GPUGather<T>(ctx, proposals, keep_index, &proposals_filter);
...@@ -421,7 +430,7 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -421,7 +430,7 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
int64_t num_proposals = 0; int64_t num_proposals = 0;
std::vector<size_t> offset(1, 0); std::vector<size_t> offset(1, 0);
std::vector<int64_t> tmp_lod; std::vector<int> tmp_num;
for (int64_t i = 0; i < num; ++i) { for (int64_t i = 0; i < num; ++i) {
Tensor im_info_slice = im_info->Slice(i, i + 1); Tensor im_info_slice = im_info->Slice(i, i + 1);
...@@ -448,15 +457,15 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -448,15 +457,15 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
dev_ctx.Wait(); dev_ctx.Wait();
num_proposals += proposals.dims()[0]; num_proposals += proposals.dims()[0];
offset.emplace_back(num_proposals); offset.emplace_back(num_proposals);
tmp_lod.push_back(num_proposals); tmp_num.push_back(proposals.dims()[0]);
} }
if (context.HasOutput("RpnRoisLod")) { if (context.HasOutput("RpnRoisNum")) {
auto *rpn_rois_lod = context.Output<Tensor>("RpnRoisLod"); auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
rpn_rois_lod->mutable_data<int64_t>({num}, context.GetPlace()); rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
int64_t *lod_data = rpn_rois_lod->data<int64_t>(); int *num_data = rpn_rois_num->data<int>();
memory::Copy(place, lod_data, cpu_place, &tmp_lod[0], memory::Copy(place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num,
sizeof(int64_t) * num, dev_ctx.stream()); dev_ctx.stream());
rpn_rois_lod->Resize({num}); rpn_rois_num->Resize({num});
} }
framework::LoD lod; framework::LoD lod;
lod.emplace_back(offset); lod.emplace_back(offset);
......
...@@ -115,7 +115,7 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -115,7 +115,7 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::string>( AddAttr<std::string>(
"padding_mode", "padding_mode",
"(bool, default true) The padding method used when source" "(bool, default true) The padding method used when source"
"index is out of input images. It can be 'zeros', 'reflect' and " "index is out of input images. It can be 'zeros', 'reflection' and "
"'border'.") "'border'.")
.SetDefault("zeros"); .SetDefault("zeros");
...@@ -174,6 +174,10 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { ...@@ -174,6 +174,10 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "grid_sampler");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Grid")), "Output",
framework::GradVarName("Grid"), "grid_sampler");
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid"); auto grid_dims = ctx->GetInputDim("Grid");
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
......
...@@ -268,7 +268,7 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> { ...@@ -268,7 +268,7 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> {
Mode mode; Mode mode;
if (padding_mode_s == "border") { if (padding_mode_s == "border") {
padding_mode = PaddingMode::border; padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflect") { } else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect; padding_mode = PaddingMode::reflect;
} else { } else {
padding_mode = PaddingMode::zeros; padding_mode = PaddingMode::zeros;
...@@ -432,7 +432,7 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -432,7 +432,7 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
Mode mode; Mode mode;
if (padding_mode_s == "border") { if (padding_mode_s == "border") {
padding_mode = PaddingMode::border; padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflect") { } else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect; padding_mode = PaddingMode::reflect;
} else { } else {
padding_mode = PaddingMode::zeros; padding_mode = PaddingMode::zeros;
......
...@@ -76,7 +76,7 @@ static inline void clip(const platform::CPUDeviceContext& ctx, ...@@ -76,7 +76,7 @@ static inline void clip(const platform::CPUDeviceContext& ctx,
if (padding_mode == "border") { if (padding_mode == "border") {
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0)) grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val)); .cwiseMin(static_cast<T>(max_val));
} else if (padding_mode == "reflect") { } else if (padding_mode == "reflection") {
if (align_corners) { if (align_corners) {
auto double_range = static_cast<T>(max_val * 2); auto double_range = static_cast<T>(max_val * 2);
auto grid_abs = grid_slice_t.abs(); auto grid_abs = grid_slice_t.abs();
...@@ -117,7 +117,7 @@ static inline void clipWithMask(const platform::CPUDeviceContext& ctx, ...@@ -117,7 +117,7 @@ static inline void clipWithMask(const platform::CPUDeviceContext& ctx,
auto in_bound = (res == grid_slice_t); auto in_bound = (res == grid_slice_t);
grid_scale_t.device(place) = grid_scale_t * in_bound.template cast<T>(); grid_scale_t.device(place) = grid_scale_t * in_bound.template cast<T>();
grid_slice_t.device(place) = res; grid_slice_t.device(place) = res;
} else if (padding_mode == "reflect") { } else if (padding_mode == "reflection") {
if (align_corners) { if (align_corners) {
auto double_range = static_cast<T>(max_val * 2); auto double_range = static_cast<T>(max_val * 2);
auto is_neg = (grid_slice_t < static_cast<T>(0)); auto is_neg = (grid_slice_t < static_cast<T>(0));
......
...@@ -72,7 +72,11 @@ class KLDivLossKernel : public framework::OpKernel<T> { ...@@ -72,7 +72,11 @@ class KLDivLossKernel : public framework::OpKernel<T> {
loss_t.device(place) = output; loss_t.device(place) = output;
} else if ("batchmean" == reduction) { } else if ("batchmean" == reduction) {
auto output_sum = output.sum(); auto output_sum = output.sum();
if (n > 0) {
loss_t.device(place) = output_sum / output_sum.constant(n); loss_t.device(place) = output_sum / output_sum.constant(n);
} else {
loss_t.device(place) = output_sum;
}
} else if ("mean" == reduction) { } else if ("mean" == reduction) {
loss_t.device(place) = output.mean(); loss_t.device(place) = output.mean();
} else if ("sum" == reduction) { } else if ("sum" == reduction) {
......
...@@ -29,11 +29,24 @@ class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> { ...@@ -29,11 +29,24 @@ class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> {
auto src_dims = src.dims(); auto src_dims = src.dims();
auto dst_dims = dst->dims(); auto dst_dims = dst->dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2UL, PADDLE_ENFORCE_EQ(src_dims.size(), 2UL,
"The src must be matrix with rank 2."); platform::errors::InvalidArgument(
"The source tensor must be a matrix with rank 2, but "
"got the source tensor rank is %lu. "
"Please check the rank of the source tensor",
src_dims.size()));
PADDLE_ENFORCE_EQ(dst_dims.size(), 2UL, PADDLE_ENFORCE_EQ(dst_dims.size(), 2UL,
"The dst must be matrix with rank 2."); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1], "The destination tensor must be a matrix with rank, "
"The width of src and dst must be same."); "but got the destination tensor rank is %lu. "
"Please check the rank of the destination tensor",
dst_dims.size()));
PADDLE_ENFORCE_EQ(
src_dims[1], dst_dims[1],
platform::errors::InvalidArgument(
"The width of the source tensor and the destination tensor must be "
"same. But got %lu != %lu.Please check the rank of the source "
"tensor",
src_dims.size(), dst_dims.size()));
auto height = dst_dims[0]; auto height = dst_dims[0];
auto width = dst_dims[1]; auto width = dst_dims[1];
auto* src_data = src.data<T>(); auto* src_data = src.data<T>();
......
...@@ -46,11 +46,24 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> { ...@@ -46,11 +46,24 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> {
auto src_dims = src.dims(); auto src_dims = src.dims();
auto dst_dims = dst->dims(); auto dst_dims = dst->dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2, PADDLE_ENFORCE_EQ(src_dims.size(), 2,
"The src must be matrix with rank 2."); platform::errors::InvalidArgument(
"The source tensor must be a matrix with rank 2, but "
"got the source tensor rank is %lu. "
"Please check the rank of the source tensor",
src_dims.size()));
PADDLE_ENFORCE_EQ(dst_dims.size(), 2, PADDLE_ENFORCE_EQ(dst_dims.size(), 2,
"The dst must be matrix with rank 2."); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1], "The destination tensor must be a matrix with rank, "
"The width of src and dst must be same."); "but got the destination tensor rank is %lu. "
"Please check the rank of the destination tensor",
dst_dims.size()));
PADDLE_ENFORCE_EQ(
src_dims[1], dst_dims[1],
platform::errors::InvalidArgument(
"The width of the source tensor and the destination tensor must be "
"same. But got %lu != %lu.Please check the rank of the source "
"tensor",
src_dims.size(), dst_dims.size()));
auto height = dst_dims[0]; auto height = dst_dims[0];
auto width = dst_dims[1]; auto width = dst_dims[1];
auto* src_data = src.data<T>(); auto* src_data = src.data<T>();
......
...@@ -64,19 +64,30 @@ class LoDTensor2BatchFunctor { ...@@ -64,19 +64,30 @@ class LoDTensor2BatchFunctor {
bool is_reverse = false) const { bool is_reverse = false) const {
if (!is_cal_batch_lod) { if (!is_cal_batch_lod) {
auto lods = batch->lod(); auto lods = batch->lod();
PADDLE_ENFORCE_GT(lods.size(), 2UL, PADDLE_ENFORCE_GT(
lods.size(), 2UL,
platform::errors::InvalidArgument(
"The LoD of LoDTensor should inlcude at least 2-level " "The LoD of LoDTensor should inlcude at least 2-level "
"sequence information."); "sequence information, but got the LoD level is %lu. Please "
"check the input value.",
lods.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
lods[1].size(), static_cast<size_t>(lod_tensor.dims()[0]), lods[1].size(), static_cast<size_t>(lod_tensor.dims()[0]),
"The LoD information should be consistent with the dims."); platform::errors::InvalidArgument(
"The LoD information should be consistent with the dims, but got "
"%lu != %lu. Please check the input value.",
lods[1].size(), static_cast<size_t>(lod_tensor.dims()[0])));
CopyMatrixRowsFunctor<DeviceContext, T> to_batch; CopyMatrixRowsFunctor<DeviceContext, T> to_batch;
to_batch(context, lod_tensor, lods[1], batch, true); to_batch(context, lod_tensor, lods[1], batch, true);
return; return;
} }
auto lods = lod_tensor.lod(); auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now."); PADDLE_ENFORCE_EQ(lods.size(), 1UL,
platform::errors::InvalidArgument(
"Only support one level sequence now, but got the "
"LoD level is %lu. Please check the input value.",
lods.size()));
const auto& lod = lods[0]; const auto& lod = lods[0];
...@@ -161,12 +172,19 @@ class Batch2LoDTensorFunctor { ...@@ -161,12 +172,19 @@ class Batch2LoDTensorFunctor {
const framework::LoDTensor& batch, const framework::LoDTensor& batch,
framework::LoDTensor* lod_tensor) const { framework::LoDTensor* lod_tensor) const {
auto in_lod = batch.lod(); auto in_lod = batch.lod();
PADDLE_ENFORCE_GT(in_lod.size(), 2UL, PADDLE_ENFORCE_GT(
in_lod.size(), 2UL,
platform::errors::InvalidArgument(
"The LoD of LoDTensor should inlcude at least 2-level " "The LoD of LoDTensor should inlcude at least 2-level "
"sequence information."); "sequence information, but got the LoD level is %lu. Please check "
"the input value.",
in_lod.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_lod[1].size(), static_cast<size_t>(lod_tensor->dims()[0]), in_lod[1].size(), static_cast<size_t>(lod_tensor->dims()[0]),
"The LoD information should be consistent with the dims."); platform::errors::InvalidArgument(
"The LoD information should be consistent with the dims, but got "
"%lu != %lu. Please check the input value.",
in_lod[1].size(), static_cast<size_t>(lod_tensor->dims()[0])));
CopyMatrixRowsFunctor<DeviceContext, T> to_seq; CopyMatrixRowsFunctor<DeviceContext, T> to_seq;
to_seq(context, batch, in_lod[1], lod_tensor, false); to_seq(context, batch, in_lod[1], lod_tensor, false);
} }
......
...@@ -35,7 +35,11 @@ void CopyValidData(framework::Tensor* dst_tensor, ...@@ -35,7 +35,11 @@ void CopyValidData(framework::Tensor* dst_tensor,
int valid_seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx]; int valid_seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
pad_seq_len, valid_seq_len, pad_seq_len, valid_seq_len,
"The padded sequence length can not be less than its original length."); platform::errors::InvalidArgument(
"The padded sequence length can not "
"be less than its original length. Expected %ld >= %ld, but got "
"%ld < %ld. Please check input value.",
pad_seq_len, valid_seq_len, pad_seq_len, valid_seq_len));
int seq_data_offset = seq_offsets[seq_idx] * step_width; int seq_data_offset = seq_offsets[seq_idx] * step_width;
int pad_data_offset = layout == kBatchLengthWidth int pad_data_offset = layout == kBatchLengthWidth
? seq_idx * pad_seq_len * step_width ? seq_idx * pad_seq_len * step_width
...@@ -95,9 +99,14 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> { ...@@ -95,9 +99,14 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout); step_width, layout);
PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width,
PADDLE_ENFORCE_EQ(
pad_value.numel() == 1 || pad_value.numel() == step_width, true,
platform::errors::InvalidArgument(
"The numel of 'pad_value' can only be 1 or be equal to the " "The numel of 'pad_value' can only be 1 or be equal to the "
"'step_width'."); "'step_width', but got %ld != 1 and %ld. Please check the input "
"value.",
pad_value.numel(), step_width));
// fill padding value // fill padding value
T* pad_data = pad_tensor->data<T>(); T* pad_data = pad_tensor->data<T>();
......
...@@ -66,17 +66,25 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -66,17 +66,25 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
if (pad_seq_len == -1) { if (pad_seq_len == -1) {
pad_seq_len = max_seq_len; pad_seq_len = max_seq_len;
} }
PADDLE_ENFORCE_GE(pad_seq_len, max_seq_len, PADDLE_ENFORCE_GE(
pad_seq_len, max_seq_len,
platform::errors::InvalidArgument(
"The pad_seq_len must be equal to or greater than the " "The pad_seq_len must be equal to or greater than the "
"original max sequence length."); "original max sequence length. Expected %ld >= %ld, but got %ld < "
"%ld. Please check the input value.",
pad_seq_len, max_seq_len, pad_seq_len, max_seq_len));
int step_width = seq_tensor.numel() / seq_tensor_dims[0]; int step_width = seq_tensor.numel() / seq_tensor_dims[0];
int seq_num = seq_offsets.size() - 1; int seq_num = seq_offsets.size() - 1;
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout); step_width, layout);
PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width, PADDLE_ENFORCE_EQ(
"The numel of 'pad_value' can only be 1 or be equal to the " pad_value.numel() == 1 || pad_value.numel() == step_width, true,
"'step_width'."); platform::errors::InvalidArgument(
"The numel of 'pad_value' can only be 1 or be equal to "
"the 'step_width', but got %ld != 1 and %ld. Please check the "
"input value.",
pad_value.numel(), step_width));
const int kBlockSize = 512; const int kBlockSize = 512;
......
...@@ -52,14 +52,25 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims, ...@@ -52,14 +52,25 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims,
const framework::Vector<size_t>& seq_offset, const framework::Vector<size_t>& seq_offset,
int64_t padded_seq_len, int64_t step_width, int64_t padded_seq_len, int64_t step_width,
const PadLayout& layout) { const PadLayout& layout) {
PADDLE_ENFORCE_EQ(static_cast<size_t>(seq_tensor_dims[0]), seq_offset.back(), PADDLE_ENFORCE_EQ(
static_cast<size_t>(seq_tensor_dims[0]), seq_offset.back(),
platform::errors::InvalidArgument(
"Value of 1st dimension of the sequence tensor should be " "Value of 1st dimension of the sequence tensor should be "
"equal to sum of lengths of all sequences."); "equal to sum of lengths of all sequences. Expected %ld == %ld, but "
"got %ld != %ld. Please check the input value.",
static_cast<size_t>(seq_tensor_dims[0]), seq_offset.back(),
static_cast<size_t>(seq_tensor_dims[0]), seq_offset.back()));
PADDLE_ENFORCE(seq_tensor_dims.size() + 1 == pad_tensor_dims.size() || PADDLE_ENFORCE_EQ(
seq_tensor_dims.size() + 1 == pad_tensor_dims.size() ||
seq_tensor_dims.size() == pad_tensor_dims.size(), seq_tensor_dims.size() == pad_tensor_dims.size(),
true, platform::errors::InvalidArgument(
"pad_tensor's rank should be 1 greater than seq_tensor's " "pad_tensor's rank should be 1 greater than seq_tensor's "
"rank, or be equal with it."); "rank, or be equal with it. The pad_tensor's rank is %ld, "
"expected the seq_tensor's rank is %ld or %ld, but got %ld. "
"Please check the input value.",
pad_tensor_dims.size(), pad_tensor_dims.size(),
pad_tensor_dims.size() - 1, seq_tensor_dims.size()));
} }
/* /*
......
...@@ -42,15 +42,29 @@ class MaxSeqPoolFunctor { ...@@ -42,15 +42,29 @@ class MaxSeqPoolFunctor {
auto out_dims = output->dims(); auto out_dims = output->dims();
auto idx_dims = index->dims(); auto idx_dims = index->dims();
PADDLE_ENFORCE_GT(in_dims.size(), 1, PADDLE_ENFORCE_GT(in_dims.size(), 1,
"The rank of input shall be greater than 1."); platform::errors::InvalidArgument(
"The rank of input shall be greater than 1, but got "
"the rank is %ld. Please check the input value",
in_dims.size()));
PADDLE_ENFORCE_GT(out_dims.size(), 1, PADDLE_ENFORCE_GT(out_dims.size(), 1,
"The rank of output shall be greater than 1."); platform::errors::InvalidArgument(
"The rank of output shall be greater than 1, but got "
"the rank is %ld. Please check the input value",
out_dims.size()));
for (int64_t i = 1; i < in_dims.size(); ++i) { for (int64_t i = 1; i < in_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i], PADDLE_ENFORCE_EQ(
"The dimension of input and output shall be same."); in_dims[i], out_dims[i],
platform::errors::InvalidArgument(
"The dimension of input and output shall be same. Expected %ld "
"== %ld, but got %ld != %ld. Please check the input value.",
in_dims[i], out_dims[i], in_dims[i], out_dims[i]));
} }
PADDLE_ENFORCE_EQ(idx_dims, out_dims, PADDLE_ENFORCE_EQ(
"The dimension of index and output shall be same."); idx_dims, out_dims,
platform::errors::InvalidArgument(
"The dimension of index and output shall be same. Expected %ld == "
"%ld, but got %ld != %ld. Please check the input value.",
idx_dims, out_dims, idx_dims, out_dims));
auto lod_level = input.lod().size(); auto lod_level = input.lod().size();
auto starts = input.lod()[lod_level - 1]; auto starts = input.lod()[lod_level - 1];
...@@ -94,12 +108,22 @@ class MaxSeqPoolFunctor<T, true> { ...@@ -94,12 +108,22 @@ class MaxSeqPoolFunctor<T, true> {
auto in_dims = input.dims(); auto in_dims = input.dims();
auto out_dims = output->dims(); auto out_dims = output->dims();
PADDLE_ENFORCE_GT(in_dims.size(), 1, PADDLE_ENFORCE_GT(in_dims.size(), 1,
"The rank of input shall be greater than 1."); platform::errors::InvalidArgument(
"The rank of input shall be greater than 1, but got "
"%ld <= 1. Please check the input value.",
in_dims.size()));
PADDLE_ENFORCE_GT(out_dims.size(), 1, PADDLE_ENFORCE_GT(out_dims.size(), 1,
"The rank of output shall be greater than 1."); platform::errors::InvalidArgument(
"The rank of output shall be greater than 1, but got "
"%ld <= 1. Please check the input value.",
out_dims.size()));
for (int64_t i = 1; i < in_dims.size(); ++i) { for (int64_t i = 1; i < in_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i], PADDLE_ENFORCE_EQ(
"The dimension of input and output shall be same."); in_dims[i], out_dims[i],
platform::errors::InvalidArgument(
"The dimension of input and output shall be same. Expected %ld "
"== %ld, but got %ld != %ld. Please check the input value.",
in_dims[i], out_dims[i], in_dims[i], out_dims[i]));
} }
auto lod_level = input.lod().size(); auto lod_level = input.lod().size();
...@@ -139,16 +163,29 @@ class MaxSeqPoolGradFunctor { ...@@ -139,16 +163,29 @@ class MaxSeqPoolGradFunctor {
auto ig_dims = in_grad->dims(); auto ig_dims = in_grad->dims();
auto idx_dims = index.dims(); auto idx_dims = index.dims();
PADDLE_ENFORCE_GT(og_dims.size(), 1, PADDLE_ENFORCE_GT(og_dims.size(), 1,
"The rank of output@Grad shall be greater than 1."); platform::errors::InvalidArgument(
"The rank of output@Grad shall be greater than 1, "
"but got %ld <= 1. Please check the input value.",
og_dims.size()));
PADDLE_ENFORCE_GT(ig_dims.size(), 1, PADDLE_ENFORCE_GT(ig_dims.size(), 1,
"The rank of input@Grad shall be greater than 1."); platform::errors::InvalidArgument(
"The rank of input@Grad shall be greater than 1, but "
"got %ld <= 1. Please check the input value.",
ig_dims.size()));
for (int64_t i = 1; i < og_dims.size(); ++i) { for (int64_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i],
og_dims[i], ig_dims[i], platform::errors::InvalidArgument(
"The dimension of input@Grad and output@Grad shall be same."); "The dimension of input@Grad and output@Grad shall "
"be same. Expected %ld == %ld, but got %ld != %ld. "
"Please check the input value.",
og_dims[i], ig_dims[i], og_dims[i], ig_dims[i]));
} }
PADDLE_ENFORCE_EQ(idx_dims, og_dims, PADDLE_ENFORCE_EQ(
"The dimension of index and output@Grad shall be same."); idx_dims, og_dims,
platform::errors::InvalidArgument(
"The dimension of index and output@Grad shall be same. Expected "
"%ld == %ld, but got %ld != %ld. Please check the input value.",
idx_dims, og_dims, idx_dims, og_dims));
const T* og_data = out_grad.data<T>(); const T* og_data = out_grad.data<T>();
const int* max_index = index.data<int>(); const int* max_index = index.data<int>();
...@@ -244,9 +281,12 @@ class SumSeqPoolGradFunctor { ...@@ -244,9 +281,12 @@ class SumSeqPoolGradFunctor {
auto lod = in_grad->lod()[lod_level - 1]; auto lod = in_grad->lod()[lod_level - 1];
int64_t out_w = out_grad.numel() / out_grad.dims()[0]; int64_t out_w = out_grad.numel() / out_grad.dims()[0];
int64_t in_w = in_grad->numel() / in_grad->dims()[0]; int64_t in_w = in_grad->numel() / in_grad->dims()[0];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(in_w, out_w,
in_w, out_w, platform::errors::InvalidArgument(
"The feature size of input@Grad and output@Grad shall be same."); "The feature size of input@Grad and output@Grad "
"shall be same. Expected %ld == %ld, but got %ld != "
"%ld. Please check the input value.",
in_w, out_w, in_w, out_w));
const T* out_g_data = out_grad.data<T>(); const T* out_g_data = out_grad.data<T>();
T* in_g_data = in_grad->mutable_data<T>(context.GetPlace()); T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
...@@ -298,7 +338,8 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -298,7 +338,8 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
auto place = context.GetPlace(); auto place = context.GetPlace();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
platform::is_cpu_place(place), true, platform::is_cpu_place(place), true,
"Sequence_pool should run on CPU Device when pooltype is SUM"); platform::errors::InvalidArgument(
"Sequence_pool should run on CPU Device when pooltype is SUM"));
const T* src = input.data<T>(); const T* src = input.data<T>();
T* dst = output->mutable_data<T>(place); T* dst = output->mutable_data<T>(place);
jit::seq_pool_attr_t attr( jit::seq_pool_attr_t attr(
...@@ -342,7 +383,10 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -342,7 +383,10 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) / out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h)); std::sqrt(static_cast<T>(h));
} else { } else {
PADDLE_THROW("unsupported pooling pooltype"); PADDLE_THROW(platform::errors::InvalidArgument(
"unsupported pooling pooltype: %s. Only support \"AVERAGE\" and "
"\"SQRT\"",
pooltype));
} }
} }
} }
...@@ -400,7 +444,10 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> { ...@@ -400,7 +444,10 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
} else if (pooltype == "FIRST") { } else if (pooltype == "FIRST") {
in_g_e.chip(0, 0).device(place) = out_g_e_v; in_g_e.chip(0, 0).device(place) = out_g_e_v;
} else { } else {
PADDLE_THROW("unsupported pooling pooltype"); PADDLE_THROW(platform::errors::InvalidArgument(
"unsupported pooling pooltype: %s. Only support \"AVERAGE\", "
"\"SQRT\", \"LAST\" and \"FIRST\"",
pooltype));
} }
} }
} }
......
...@@ -205,7 +205,10 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> { ...@@ -205,7 +205,10 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> {
lod.CUDAData(context.GetPlace()), lod.size(), item_dim, lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr); output->mutable_data<T>(context.GetPlace()), nullptr);
} else { } else {
PADDLE_THROW("unsupported pooling pooltype"); PADDLE_THROW(platform::errors::InvalidArgument(
"unsupported pooling pooltype: %s. Only support \"MAX\", "
"\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"",
pooltype));
} }
} }
}; };
...@@ -370,7 +373,10 @@ class SequencePoolGradFunctor<platform::CUDADeviceContext, T> { ...@@ -370,7 +373,10 @@ class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
in_grad->mutable_data<T>(context.GetPlace()), nullptr); in_grad->mutable_data<T>(context.GetPlace()), nullptr);
} else { } else {
PADDLE_THROW("unsupported pooling pooltype"); PADDLE_THROW(platform::errors::InvalidArgument(
"unsupported pooling pooltype: %s. Only support \"MAX\", "
"\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"",
pooltype));
} }
} }
}; };
......
...@@ -50,9 +50,21 @@ void TestSequencePoolingSum(const DeviceContext &context, ...@@ -50,9 +50,21 @@ void TestSequencePoolingSum(const DeviceContext &context,
in_grad.mutable_data<T>(in_dims, place); in_grad.mutable_data<T>(in_dims, place);
// check tensor contruction result // check tensor contruction result
PADDLE_ENFORCE_EQ(in_grad.dims().size(), out_grad.dims().size()); PADDLE_ENFORCE_EQ(
in_grad.dims().size(), out_grad.dims().size(),
paddle::platform::errors::InvalidArgument(
"The dimension of input and output shall be same. Expected %ld == "
"%ld, but got %ld != %ld. Please check the input value.",
in_grad.dims().size(), out_grad.dims().size(), in_grad.dims().size(),
out_grad.dims().size()));
for (int64_t i = 1; i < out_grad.dims().size(); ++i) { for (int64_t i = 1; i < out_grad.dims().size(); ++i) {
PADDLE_ENFORCE_EQ(in_grad.dims()[i], out_grad.dims()[i]); PADDLE_ENFORCE_EQ(
in_grad.dims()[i], out_grad.dims()[i],
paddle::platform::errors::InvalidArgument(
"The dimension of input and output shall be same. Expected %ld == "
"%ld, but got %ld != %ld. Please check the input value.",
in_grad.dims()[i], out_grad.dims()[i], in_grad.dims()[i],
out_grad.dims()[i]));
} }
// call functor // call functor
......
...@@ -55,7 +55,11 @@ void Tree2ColUtil::construct_tree(const paddle::Tensor &EdgeSet, ...@@ -55,7 +55,11 @@ void Tree2ColUtil::construct_tree(const paddle::Tensor &EdgeSet,
std::vector<std::vector<int>> *tr, std::vector<std::vector<int>> *tr,
size_t *node_count) { size_t *node_count) {
auto edge_set_dims = EdgeSet.dims(); auto edge_set_dims = EdgeSet.dims();
PADDLE_ENFORCE_EQ(edge_set_dims[1], 2); PADDLE_ENFORCE_EQ(edge_set_dims[1], 2,
platform::errors::InvalidArgument(
"The second dimension of the EdgeSet shall be 2, but "
"got %ld != 2. Please check the input value.",
edge_set_dims[1]));
int64_t edge_count = EdgeSet.numel(); int64_t edge_count = EdgeSet.numel();
const int *edge_data = EdgeSet.data<int>(); const int *edge_data = EdgeSet.data<int>();
......
...@@ -37,7 +37,13 @@ class Unpool2dMaxFunctor<platform::CPUDeviceContext, T> { ...@@ -37,7 +37,13 @@ class Unpool2dMaxFunctor<platform::CPUDeviceContext, T> {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) { for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i]; int index = indices_data[i];
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); PADDLE_ENFORCE_LT(
index, output_feasize,
platform::errors::InvalidArgument(
"index should less than output tensor height * output tensor "
"width. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
index, output_feasize, index, output_feasize));
output_data[index] = input_data[i]; output_data[index] = input_data[i];
} }
input_data += input_feasize; input_data += input_feasize;
...@@ -72,7 +78,13 @@ class Unpool2dMaxGradFunctor<platform::CPUDeviceContext, T> { ...@@ -72,7 +78,13 @@ class Unpool2dMaxGradFunctor<platform::CPUDeviceContext, T> {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) { for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i]; int index = indices_data[i];
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); PADDLE_ENFORCE_LT(
index, output_feasize,
platform::errors::InvalidArgument(
"index should less than output tensor height * output tensor "
"width. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
index, output_feasize, index, output_feasize));
input_grad_data[i] = output_grad_data[index]; input_grad_data[i] = output_grad_data[index];
} }
input_grad_data += input_feasize; input_grad_data += input_feasize;
......
...@@ -105,6 +105,12 @@ class PnormOp : public framework::OperatorWithKernel { ...@@ -105,6 +105,12 @@ class PnormOp : public framework::OperatorWithKernel {
bool asvector = ctx->Attrs().Get<bool>("asvector"); bool asvector = ctx->Attrs().Get<bool>("asvector");
if (asvector) { if (asvector) {
reduce_dims.emplace_back(1); reduce_dims.emplace_back(1);
if (keepdim) {
for (int i = 1; i < x_dim.size(); ++i) {
reduce_dims.emplace_back(1);
}
x_dim = framework::make_ddim(reduce_dims);
}
} else { } else {
if (axis < 0) axis = x_dim.size() + axis; if (axis < 0) axis = x_dim.size() + axis;
for (int i = 0; i < x_dim.size(); ++i) { for (int i = 0; i < x_dim.size(); ++i) {
......
...@@ -51,6 +51,20 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -51,6 +51,20 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
// Grad-op maker whose Apply() emits an op of type "reduce_sum". In this file
// it is handed to REGISTER_OPERATOR for reduce_sum_grad, instantiated with
// paddle::framework::OpDesc and paddle::imperative::OpBase.
template <typename T>
class ReduceSumDoubleOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
// Reuse the constructors of SingleGradOpMaker<T> unchanged.
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
// Fills in `op`, the op being generated: one input slot, one output slot,
// the attribute map, and the op type.
void Apply(GradOpPtr<T> op) const override {
// Input "X" is whatever OutputGrad() yields for the slot named
// GradVarName("X").
// NOTE(review): presumably this is the gradient arriving at the X@GRAD
// output of reduce_sum_grad -- confirm against SingleGradOpMaker.
op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
// Output "Out" is whatever InputGrad() yields for the slot named
// GradVarName("Out").
op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
// Pass this maker's attribute map through to the generated op as-is.
op->SetAttrMap(this->Attrs());
// The generated op is an ordinary reduce_sum rather than a dedicated
// second-order kernel.
op->SetType("reduce_sum");
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInferer, "X"); DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInferer, "X");
class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference { class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
public: public:
...@@ -77,6 +91,8 @@ REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker, ...@@ -77,6 +91,8 @@ REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>, ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>); ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp, REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
ops::ReduceSumDoubleOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumDoubleOpGradMaker<paddle::imperative::OpBase>,
ops::ReduceSumGradNoNeedBufferVarInferer); ops::ReduceSumGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -11,6 +11,7 @@ limitations under the License. */ ...@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/fluid/operators/roi_align_op.h" #include "paddle/fluid/operators/roi_align_op.h"
#include <memory> #include <memory>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -35,13 +36,13 @@ class ROIAlignOp : public framework::OperatorWithKernel { ...@@ -35,13 +36,13 @@ class ROIAlignOp : public framework::OperatorWithKernel {
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs"); auto rois_dims = ctx->GetInputDim("ROIs");
if (ctx->HasInput("RoisLod")) { if (ctx->HasInput("RoisNum")) {
auto rois_lod_dims = ctx->GetInputDim("RoisLod"); auto rois_num_dims = ctx->GetInputDim("RoisNum");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_lod_dims.size(), 1, rois_num_dims.size(), 1,
platform::errors::InvalidArgument("The RoisLod dimension should be 1" platform::errors::InvalidArgument("The size of RoisNum should be 1"
", but got dimension = %d", ", but received size = %d",
rois_lod_dims.size())); rois_num_dims.size()));
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_dims.size(), 4, input_dims.size(), 4,
...@@ -145,9 +146,9 @@ class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -145,9 +146,9 @@ class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker {
"given as [[x1, y1, x2, y2], ...]. " "given as [[x1, y1, x2, y2], ...]. "
"(x1, y1) is the top left coordinates, and " "(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates."); "(x2, y2) is the bottom right coordinates.");
AddInput("RoisLod", AddInput("RoisNum",
"(Tensor), " "(Tensor), "
"The lod info of rois.") "The number of RoIs in each image.")
.AsDispensable(); .AsDispensable();
AddOutput("Out", AddOutput("Out",
"(Tensor), " "(Tensor), "
...@@ -203,7 +204,7 @@ class ROIAlignGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -203,7 +204,7 @@ class ROIAlignGradMaker : public framework::SingleGradOpMaker<T> {
op->SetType("roi_align_grad"); op->SetType("roi_align_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("ROIs", this->Input("ROIs")); op->SetInput("ROIs", this->Input("ROIs"));
op->SetInput("RoisLod", this->Input("RoisLod")); op->SetInput("RoisNum", this->Input("RoisNum"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
...@@ -231,3 +232,10 @@ REGISTER_OP_CPU_KERNEL( ...@@ -231,3 +232,10 @@ REGISTER_OP_CPU_KERNEL(
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, float>, ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, double>, ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, int>); ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, int>);
// Registers a version checkpoint for the roi_align op. Both the checkpoint
// note and the OpVersionDesc record a single change: a new input named
// "RoisNum", described as the dispensable number of RoIs in each image.
REGISTER_OP_VERSION(roi_align)
.AddCheckpoint(
R"ROC(
Upgrade roi_align add a new input [RoisNum])ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"RoisNum",
"The number of RoIs in each image. RoisNum is dispensable."));
...@@ -257,24 +257,26 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -257,24 +257,26 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace); int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto& dev_ctx = ctx.cuda_device_context(); auto& dev_ctx = ctx.cuda_device_context();
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
if (ctx.HasInput("RoisLod")) { if (ctx.HasInput("RoisNum")) {
auto* rois_lod = ctx.Input<Tensor>("RoisLod"); auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
int rois_batch_size = rois_lod->numel(); int rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_batch_size - 1, batch_size, rois_batch_size, batch_size,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The rois_batch_size and imgs " "The rois_batch_size and imgs "
"batch_size must be the same. But received rois_batch_size = %d, " "batch_size must be the same. But received rois_batch_size = %d, "
"batch_size = %d", "batch_size = %d",
rois_batch_size, batch_size)); rois_batch_size, batch_size));
std::vector<int64_t> rois_lod_(rois_batch_size); std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(cplace, rois_lod_.data(), gplace, rois_lod->data<int64_t>(), memory::Copy(cplace, rois_num_list.data(), gplace,
sizeof(int64_t) * rois_batch_size, 0); rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
for (int n = 0; n < rois_batch_size - 1; ++n) { int start = 0;
for (size_t i = rois_lod_[n]; i < rois_lod_[n + 1]; ++i) { for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
roi_batch_id_data[i] = n; roi_batch_id_data[i] = n;
} }
start += rois_num_list[n];
} }
} else { } else {
auto lod = rois->lod(); auto lod = rois->lod();
...@@ -348,16 +350,18 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -348,16 +350,18 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.cuda_device_context(); auto& dev_ctx = ctx.cuda_device_context();
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
if (ctx.HasInput("RoisLod")) { if (ctx.HasInput("RoisNum")) {
auto* rois_lod = ctx.Input<Tensor>("RoisLod"); auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
int rois_batch_size = rois_lod->numel(); int rois_batch_size = rois_num_t->numel();
std::vector<int64_t> rois_lod_(rois_batch_size); std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(cplace, rois_lod_.data(), gplace, rois_lod->data<int64_t>(), memory::Copy(cplace, rois_num_list.data(), gplace,
sizeof(int64_t) * rois_batch_size, 0); rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
for (int n = 0; n < rois_batch_size - 1; ++n) { int start = 0;
for (size_t i = rois_lod_[n]; i < rois_lod_[n + 1]; ++i) { for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = start; i < start + rois_num_list[n]; ++i) {
roi_batch_id_data[i] = n; roi_batch_id_data[i] = n;
} }
start += rois_num_list[n];
} }
} else { } else {
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
......
...@@ -165,21 +165,23 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -165,21 +165,23 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
int* roi_batch_id_data = int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace()); roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size; int rois_batch_size;
if (ctx.HasInput("RoisLod")) { if (ctx.HasInput("RoisNum")) {
auto* rois_lod_t = ctx.Input<framework::Tensor>("RoisLod"); auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_lod_t->numel(); rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_batch_size - 1, batch_size, rois_batch_size, batch_size,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The batch size of rois and the batch size of images " "The batch size of rois and the batch size of images "
" must be the same. But received the batch size of rois is %d, " " must be the same. But received the batch size of rois is %d, "
"and the batch size of images is %d", "and the batch size of images is %d",
rois_batch_size, batch_size)); rois_batch_size, batch_size));
auto* rois_lod = rois_lod_t->data<int64_t>(); auto* rois_num_data = rois_num_t->data<int>();
for (int n = 0; n < rois_batch_size - 1; ++n) { int start = 0;
for (int i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
roi_batch_id_data[i] = n; roi_batch_id_data[i] = n;
} }
start += rois_num_data[n];
} }
} else { } else {
auto lod = rois->lod(); auto lod = rois->lod();
...@@ -303,14 +305,16 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -303,14 +305,16 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
roi_batch_id_list.mutable_data<int>(ctx.GetPlace()); roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size; int rois_batch_size;
if (ctx.HasInput("RoisLod")) { if (ctx.HasInput("RoisNum")) {
auto* rois_lod_t = ctx.Input<framework::Tensor>("RoisLod"); auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_lod_t->numel(); rois_batch_size = rois_num_t->numel();
auto* rois_lod = rois_lod_t->data<int64_t>(); auto* rois_num_data = rois_num_t->data<int>();
for (int n = 0; n < rois_batch_size - 1; ++n) { int start = 0;
for (int i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
roi_batch_id_data[i] = n; roi_batch_id_data[i] = n;
} }
start += rois_num_data[n];
} }
} else { } else {
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/roi_pool_op.h" #include "paddle/fluid/operators/roi_pool_op.h"
#include <memory> #include <memory>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -34,12 +35,13 @@ class ROIPoolOp : public framework::OperatorWithKernel { ...@@ -34,12 +35,13 @@ class ROIPoolOp : public framework::OperatorWithKernel {
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs"); auto rois_dims = ctx->GetInputDim("ROIs");
if (ctx->HasInput("RoisLod")) { if (ctx->HasInput("RoisNum")) {
auto rois_lod_dims = ctx->GetInputDim("RoisLod"); auto rois_num_dims = ctx->GetInputDim("RoisNum");
PADDLE_ENFORCE_EQ(rois_lod_dims.size(), 1, PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The lod information tensor of ROIs should " "The second dimension of RoisNum should "
"be one-dimensional")); "be 1, but received dimension is %d",
rois_num_dims.size()));
} }
PADDLE_ENFORCE_EQ(input_dims.size(), 4, PADDLE_ENFORCE_EQ(input_dims.size(), 4,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -140,7 +142,8 @@ class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -140,7 +142,8 @@ class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"Where batch_id is the id of the data, " "Where batch_id is the id of the data, "
"(x1, y1) is the top left coordinates, and " "(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates."); "(x2, y2) is the bottom right coordinates.");
AddInput("RoisLod", "(Tensor), The lod info of rois.").AsDispensable(); AddInput("RoisNum", "(Tensor), The number of RoIs in each image.")
.AsDispensable();
AddOutput("Out", AddOutput("Out",
"(Tensor), " "(Tensor), "
"The output of ROIPoolOp is a 4-D tensor with shape " "The output of ROIPoolOp is a 4-D tensor with shape "
...@@ -197,7 +200,7 @@ class ROIPoolGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -197,7 +200,7 @@ class ROIPoolGradMaker : public framework::SingleGradOpMaker<T> {
op->SetType("roi_pool_grad"); op->SetType("roi_pool_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("ROIs", this->Input("ROIs")); op->SetInput("ROIs", this->Input("ROIs"));
op->SetInput("RoisLod", this->Input("RoisLod")); op->SetInput("RoisNum", this->Input("RoisNum"));
op->SetInput("Argmax", this->Output("Argmax")); op->SetInput("Argmax", this->Output("Argmax"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
...@@ -223,3 +226,10 @@ REGISTER_OP_CPU_KERNEL( ...@@ -223,3 +226,10 @@ REGISTER_OP_CPU_KERNEL(
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>, ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>, ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, int>); ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, int>);
// Registers an op-version checkpoint for roi_pool. The checkpoint records
// that a new input named "RoisNum" was added to the operator; the note and the
// input description below are the texts stored with that checkpoint.
REGISTER_OP_VERSION(roi_pool)
.AddCheckpoint(
R"ROC(
Upgrade roi_pool add a new input [RoisNum])ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"RoisNum",
"The number of RoIs in each image. RoisNum is dispensable."));
...@@ -157,19 +157,21 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> { ...@@ -157,19 +157,21 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace); int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto& dev_ctx = ctx.cuda_device_context(); auto& dev_ctx = ctx.cuda_device_context();
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
if (ctx.HasInput("RoisLod")) { if (ctx.HasInput("RoisNum")) {
auto* rois_lod = ctx.Input<Tensor>("RoisLod"); auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
int rois_batch_size = rois_lod->numel(); int rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_batch_size - 1, batch_size, rois_batch_size, batch_size,
"The rois_batch_size and imgs batch_size must be the same."); "The rois_batch_size and imgs batch_size must be the same.");
std::vector<int64_t> rois_lod_(rois_batch_size); std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(cplace, rois_lod_.data(), gplace, rois_lod->data<int64_t>(), memory::Copy(cplace, rois_num_list.data(), gplace,
sizeof(int64_t) * rois_batch_size, 0); rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
for (int n = 0; n < rois_batch_size - 1; ++n) { int start = 0;
for (size_t i = rois_lod_[n]; i < rois_lod_[n + 1]; ++i) { for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
roi_batch_id_data[i] = n; roi_batch_id_data[i] = n;
} }
start += rois_num_list[n];
} }
} else { } else {
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
...@@ -206,7 +208,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> { ...@@ -206,7 +208,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X"); auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs"); auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* rois_lod = ctx.Input<Tensor>("RoisLod"); auto* rois_lod = ctx.Input<Tensor>("RoisNum");
auto* argmax = ctx.Input<Tensor>("Argmax"); auto* argmax = ctx.Input<Tensor>("Argmax");
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
...@@ -229,17 +231,18 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> { ...@@ -229,17 +231,18 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.cuda_device_context(); auto& dev_ctx = ctx.cuda_device_context();
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
if (ctx.HasInput("RoisLod")) { if (ctx.HasInput("RoisNum")) {
auto* rois_lod = ctx.Input<Tensor>("RoisLod"); auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
int rois_batch_size = rois_lod->numel(); int rois_batch_size = rois_num_t->numel();
std::vector<int64_t> rois_lod_(rois_batch_size); std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(cplace, rois_lod_.data(), gplace, memory::Copy(cplace, rois_num_list.data(), gplace,
rois_lod->data<int64_t>(), rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
sizeof(int64_t) * rois_batch_size, 0); int start = 0;
for (int n = 0; n < rois_batch_size - 1; ++n) { for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod_[n]; i < rois_lod_[n + 1]; ++i) { for (int i = start; i < start + rois_num_list[n]; ++i) {
roi_batch_id_data[i] = n; roi_batch_id_data[i] = n;
} }
start += rois_num_list[n];
} }
} else { } else {
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
......
...@@ -58,18 +58,20 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> { ...@@ -58,18 +58,20 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> {
roi_batch_id_list.mutable_data<int>(ctx.GetPlace()); roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size; int rois_batch_size;
if (ctx.HasInput("RoisLod")) { if (ctx.HasInput("RoisNum")) {
auto* rois_lod_t = ctx.Input<framework::Tensor>("RoisLod"); auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_lod_t->numel(); rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_batch_size - 1, batch_size, rois_batch_size, batch_size,
platform::errors::InvalidArgument("The rois_batch_size and imgs " platform::errors::InvalidArgument("The rois_batch_size and imgs "
"batch_size must be the same.")); "batch_size must be the same."));
auto* rois_lod = rois_lod_t->data<int64_t>(); auto* rois_num_data = rois_num_t->data<int>();
for (int n = 0; n < rois_batch_size - 1; ++n) { int start = 0;
for (int i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
roi_batch_id_data[i] = n; roi_batch_id_data[i] = n;
} }
start += rois_num_data[n];
} }
} else { } else {
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
...@@ -185,14 +187,16 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> { ...@@ -185,14 +187,16 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
roi_batch_id_list.mutable_data<int>(ctx.GetPlace()); roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size; int rois_batch_size;
if (ctx.HasInput("RoisLod")) { if (ctx.HasInput("RoisNum")) {
auto* rois_lod_t = ctx.Input<framework::Tensor>("RoisLod"); auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_lod_t->numel(); rois_batch_size = rois_num_t->numel();
auto* rois_lod = rois_lod_t->data<int64_t>(); auto* rois_num_data = rois_num_t->data<int>();
for (int n = 0; n < rois_batch_size - 1; ++n) { int start = 0;
for (int i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
roi_batch_id_data[i] = n; roi_batch_id_data[i] = n;
} }
start += rois_num_data[n];
} }
} else { } else {
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
......
...@@ -43,6 +43,11 @@ std::map<std::string, std::set<std::string>> op_ins_map = { ...@@ -43,6 +43,11 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"nll_loss", {"X", "Label", "Weight"}}, {"nll_loss", {"X", "Label", "Weight"}},
{"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}}, {"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}},
{"gather", {"X", "Index", "Axis"}}, {"gather", {"X", "Index", "Axis"}},
{"roi_pool", {"X", "ROIs", "RoisNum"}},
{"roi_align", {"X", "ROIs", "RoisNum"}},
{"collect_fpn_proposals",
{"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}},
{"distribute_fpn_proposals", {"FpnRois", "RoisNum"}},
}; };
// NOTE(zhiqiu): Like op_ins_map. // NOTE(zhiqiu): Like op_ins_map.
...@@ -63,6 +68,10 @@ std::map<std::string, std::set<std::string>> op_outs_map = { ...@@ -63,6 +68,10 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}}, "ReserveSpace"}},
{"unique", {"Out", "Index", "Indices", "Counts"}}, {"unique", {"Out", "Index", "Indices", "Counts"}},
{"generate_proposals", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
{"collect_fpn_proposals", {"FpnRois", "RoisNum"}},
{"distribute_fpn_proposals",
{"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
}; };
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
......
...@@ -111,6 +111,8 @@ goto:success ...@@ -111,6 +111,8 @@ goto:success
:CASE_wincheck_openblas :CASE_wincheck_openblas
set WITH_MKL=OFF set WITH_MKL=OFF
set WITH_GPU=ON set WITH_GPU=ON
rem Temporarily turn off WITH_INFERENCE_API_TEST on GPU due to compile hang
set WITH_INFERENCE_API_TEST=OFF
call :cmake || goto cmake_error call :cmake || goto cmake_error
call :build || goto build_error call :build || goto build_error
call :test_whl_pacakage || goto test_whl_pacakage_error call :test_whl_pacakage || goto test_whl_pacakage_error
...@@ -242,7 +244,7 @@ dir %THIRD_PARTY_PATH:/=\%\install\mkldnn\bin ...@@ -242,7 +244,7 @@ dir %THIRD_PARTY_PATH:/=\%\install\mkldnn\bin
dir %THIRD_PARTY_PATH:/=\%\install\warpctc\bin dir %THIRD_PARTY_PATH:/=\%\install\warpctc\bin
set PATH=%THIRD_PARTY_PATH:/=\%\install\openblas\lib;%THIRD_PARTY_PATH:/=\%\install\openblas\bin;%THIRD_PARTY_PATH:/=\%\install\zlib\bin;%THIRD_PARTY_PATH:/=\%\install\mklml\lib;%THIRD_PARTY_PATH:/=\%\install\mkldnn\bin;%THIRD_PARTY_PATH:/=\%\install\warpctc\bin;%PATH% set PATH=%THIRD_PARTY_PATH:/=\%\install\openblas\lib;%THIRD_PARTY_PATH:/=\%\install\openblas\bin;%THIRD_PARTY_PATH:/=\%\install\zlib\bin;%THIRD_PARTY_PATH:/=\%\install\mklml\lib;%THIRD_PARTY_PATH:/=\%\install\mkldnn\bin;%THIRD_PARTY_PATH:/=\%\install\warpctc\bin;%PATH%
ctest.exe --output-on-failure -C Release -j 8 --repeat until-pass:4 ctest.exe --output-on-failure -C Release -j 8 --repeat until-pass:4 after-timeout:4
goto:eof goto:eof
:unit_test_error :unit_test_error
......
...@@ -13,6 +13,10 @@ ...@@ -13,6 +13,10 @@
from paddle import fluid from paddle import fluid
from .meta_optimizer_base import MetaOptimizerBase from .meta_optimizer_base import MetaOptimizerBase
from paddle.fluid import core
import subprocess
import re
import platform
class ParameterServerOptimizer(MetaOptimizerBase): class ParameterServerOptimizer(MetaOptimizerBase):
...@@ -28,6 +32,9 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -28,6 +32,9 @@ class ParameterServerOptimizer(MetaOptimizerBase):
def _can_apply(self): def _can_apply(self):
if self.role_maker._is_collective: if self.role_maker._is_collective:
return False return False
if self.user_defined_strategy.auto == True:
return True
k_steps = self.user_defined_strategy.a_sync_configs["k_steps"] k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
return True if k_steps >= 0 else False return True if k_steps >= 0 else False
...@@ -127,6 +134,105 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -127,6 +134,105 @@ class ParameterServerOptimizer(MetaOptimizerBase):
return _main, _startup return _main, _startup
def _try_auto_apply_geo(self, program, compiled_config):
    """Pick ``k_steps`` automatically when the user enabled ``strategy.auto``.

    Nothing is changed when ``auto`` is off or when ``k_steps`` was already
    set to a non-negative value. Otherwise ``a_sync`` is switched on and:

    * for any inner optimizer other than ``SGDOptimizer`` ``k_steps`` becomes
      0 (plain async);
    * for ``SGDOptimizer`` an upper bound of the memory needed by the program
      is estimated and compared with the free system memory: if it fits,
      ``k_steps`` becomes 800 (geo), otherwise 0 (async).

    Args:
        program: program whose global block ops/vars are inspected.
        compiled_config: object exposing ``origin_sparse_pairs`` and
            ``origin_dense_pairs`` (sequences of ``(param, grad)`` pairs).

    Raises:
        ValueError: on a platform other than Linux/Darwin, or when a
            variable has more than one negative dimension.
    """

    def get_sys_free_mem():
        """Return the free system memory in bytes (Linux and Darwin only)."""
        system = platform.system()
        if system == "Darwin":
            # universal_newlines=True makes communicate() return text on
            # Python 3 as well; splitting raw bytes with a str separator
            # raises TypeError there.
            vm = subprocess.Popen(
                ['vm_stat'],
                stdout=subprocess.PIPE,
                universal_newlines=True).communicate()[0]
            vm_lines = vm.split('\n')
            # vm_stat's header normally reports the page size ("page size
            # of N bytes"); fall back to the previous constant of 4096 if
            # it cannot be parsed.
            page_match = re.search(r'page size of (\d+) bytes', vm_lines[0])
            page_size = int(page_match.group(1)) if page_match else 4096
            sep = re.compile(r':[\s]+')
            vm_stats = {}
            for row in range(1, len(vm_lines) - 2):
                row_elements = sep.split(vm_lines[row].strip())
                # Values are printed with a trailing dot, e.g. "1234.".
                vm_stats[row_elements[0]] = int(row_elements[1].strip(
                    '\\.')) * page_size
            return vm_stats["Pages free"]
        elif system == "Linux":
            mems = {}
            with open('/proc/meminfo', 'rb') as f:
                for line in f:
                    fields = line.split()
                    # /proc/meminfo reports sizes in kB.
                    mems[fields[0]] = int(fields[1]) * 1024
            return mems[b'MemFree:']
        else:
            raise ValueError(
                "%s platform is unsupported is parameter server optimizer" %
                (system))

    if not self.user_defined_strategy.auto:
        return

    a_sync_configs = self.user_defined_strategy.a_sync_configs
    if a_sync_configs["k_steps"] >= 0:
        # The user already chose a mode explicitly.
        return

    self.user_defined_strategy.a_sync = True

    if not isinstance(self.inner_opt, fluid.optimizer.SGDOptimizer):
        # auto async
        a_sync_configs["k_steps"] = 0
        self.user_defined_strategy.a_sync_configs = a_sync_configs
        return

    from paddle.fluid.incubate.fleet.parameter_server.ir.vars_metatools import dtype_to_size

    free = get_sys_free_mem()

    param_grad_pairs = (compiled_config.origin_sparse_pairs +
                        compiled_config.origin_dense_pairs)
    processed_var_names = set(["@EMPTY@"])
    param_memory_size = 0
    for param, _ in param_grad_pairs:
        param_memory_size += param.m_size
        processed_var_names.add(param.name)

    # Parameters are weighted by a factor of 5 in the upper-bound estimate.
    upper_mem_use = param_memory_size * 5.0

    program_tmp_vars = dict()
    # Batch size substituted for a single unknown (negative) dimension.
    batch_size = 1024
    for op in program.global_block().ops:
        for var_name in op.output_arg_names:
            if var_name in processed_var_names:
                continue
            processed_var_names.add(var_name)
            var = program.global_block().vars[var_name]

            if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR:
                continue

            data_count = 1
            neg_dim_count = 0
            for x in var.shape:
                if x < 0:
                    if neg_dim_count >= 1:
                        raise ValueError(
                            "Var %s has more than one negative dim." %
                            (var_name))
                    neg_dim_count += 1
                    data_count *= (-x)
                else:
                    data_count *= x
            program_tmp_vars[var_name] = (data_count, neg_dim_count,
                                          dtype_to_size[var.dtype])

    for data_count, neg_dim_count, type_size in program_tmp_vars.values():
        if neg_dim_count == 1:
            data_count *= batch_size
        upper_mem_use += data_count * type_size

    if upper_mem_use < free:
        # auto geo
        a_sync_configs["k_steps"] = 800
    else:
        # auto async
        a_sync_configs["k_steps"] = 0
    self.user_defined_strategy.a_sync_configs = a_sync_configs
def minimize_impl(self, def minimize_impl(self,
loss, loss,
startup_program=None, startup_program=None,
...@@ -134,7 +240,6 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -134,7 +240,6 @@ class ParameterServerOptimizer(MetaOptimizerBase):
no_grad_set=None): no_grad_set=None):
self.inner_opt.minimize(loss, startup_program, parameter_list, self.inner_opt.minimize(loss, startup_program, parameter_list,
no_grad_set) no_grad_set)
strategy = self._get_distributed_strategy()
_origin_main_program = loss.block.program _origin_main_program = loss.block.program
_origin_startup_program = startup_program _origin_startup_program = startup_program
...@@ -142,7 +247,12 @@ class ParameterServerOptimizer(MetaOptimizerBase): ...@@ -142,7 +247,12 @@ class ParameterServerOptimizer(MetaOptimizerBase):
compiled_config = public.CompileTimeStrategy(_origin_main_program, compiled_config = public.CompileTimeStrategy(_origin_main_program,
_origin_startup_program, _origin_startup_program,
strategy, self.role_maker) None, self.role_maker)
self._try_auto_apply_geo(_origin_main_program, compiled_config)
strategy = self._get_distributed_strategy()
compiled_config.strategy = strategy
if self.role_maker.is_worker() or self.role_maker._is_heter_worker(): if self.role_maker.is_worker() or self.role_maker._is_heter_worker():
main_program, startup_program = self._build_trainer_programs( main_program, startup_program = self._build_trainer_programs(
......
...@@ -37,7 +37,7 @@ import warnings ...@@ -37,7 +37,7 @@ import warnings
import inspect import inspect
import numpy as np import numpy as np
import paddle
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers import utils from paddle.fluid.layers import utils
from ... import unique_name from ... import unique_name
...@@ -56,7 +56,8 @@ __all__ = [ ...@@ -56,7 +56,8 @@ __all__ = [
'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool', 'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat', 'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat',
'sparse_embedding', 'partial_sum', 'tdm_child', 'rank_attention', 'sparse_embedding', 'partial_sum', 'tdm_child', 'rank_attention',
'tdm_sampler', 'batch_fc', '_pull_box_extended_sparse', 'bilateral_slice' 'tdm_sampler', 'batch_fc', '_pull_box_extended_sparse', 'bilateral_slice',
'correlation'
] ]
...@@ -1546,3 +1547,81 @@ def bilateral_slice(x, guide, grid, has_offset, name=None): ...@@ -1546,3 +1547,81 @@ def bilateral_slice(x, guide, grid, has_offset, name=None):
attrs={'has_offset': has_offset}, attrs={'has_offset': has_offset},
outputs={'Out': out}) outputs={'Out': out})
return out return out
def correlation(x,
                y,
                pad_size,
                kernel_size,
                max_displacement,
                stride1,
                stride2,
                corr_type_multiply=1):
    """
    This operation computes the correlation of two tensors.
    For more information about correlation, please refer to
    `PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
    <https://arxiv.org/pdf/1709.02371.pdf>`_

    Args:
        x(Tensor): The input x is 4-D Tensor with shape [N, C, H, W]. The data type is float32 and float64.
        y(Tensor): The input y is 4-D Tensor with shape [N, C, H, W]. The data type is float32 and float64.
        pad_size(int): Pad size. The data type is int.
        kernel_size(int): Kernel size. The data type is int.
        max_displacement(int): Max displacement. The data type is int.
        stride1(int): stride size of x. The data type is int.
        stride2(int): stride size of y. The data type is int.
        corr_type_multiply(int, optional): The type of multiply. The data type is int. Default: 1.

    Returns:
        Tensor: The data type is same as input tensor.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            x_shape = (2, 3, 4, 5)
            x_type = 'float32'
            x1 = fluid.layers.data(name='x1',
                                   shape=x_shape,
                                   dtype=x_type,
                                   append_batch_size=False)
            x2 = fluid.layers.data(name='x2',
                                   shape=x_shape,
                                   dtype=x_type,
                                   append_batch_size=False)

            out = fluid.contrib.correlation(
                x1,
                x2,
                pad_size=4,
                kernel_size=1,
                max_displacement=4,
                stride1=1,
                stride2=1)
    """
    if paddle.fluid.in_dygraph_mode():
        # Dygraph calls the op directly; attributes are passed as a flat
        # (name, value, name, value, ...) sequence.
        attrs = ("pad_size", pad_size, "kernel_size", kernel_size,
                 "max_displacement", max_displacement, "stride1", stride1,
                 "stride2", stride2, "corr_type_multiply", corr_type_multiply)
        output = core.ops.correlation(x, y, *attrs)
    else:
        # The helper and the output variable are only needed when the op is
        # appended to a static program. locals() must be taken before any
        # other local is bound in this branch.
        helper = LayerHelper("correlation", **locals())
        output = helper.create_variable_for_type_inference(dtype=x.dtype)
        helper.append_op(
            type="correlation",
            inputs={"Input1": x,
                    "Input2": y},
            attrs={
                "pad_size": pad_size,
                "kernel_size": kernel_size,
                "max_displacement": max_displacement,
                "stride1": stride1,
                "stride2": stride2,
                "corr_type_multiply": corr_type_multiply
            },
            outputs={"Output": output})
    return output
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
def corr(x_1,
         x_2,
         pad_size=4,
         kernel_size=1,
         max_displacement=4,
         stride1=1,
         stride2=1,
         corr_multiply=1):
    """NumPy reference implementation of the correlation op used by the tests.

    Both inputs are 4-D arrays whose last two axes are zero-padded by
    ``pad_size`` on each side. For every output position and every
    displacement ``(k, l)`` in ``[-max_displacement, max_displacement]``, the
    result is the mean of the element-wise product between the
    ``kernel_size`` x ``kernel_size`` window of ``x_1`` and the displaced
    window of ``x_2`` (the mean also runs over axis 1 of the inputs).

    ``stride1``, ``stride2`` and ``corr_multiply`` are accepted for signature
    parity but are not used by this reference.

    Returns:
        float32 array of shape ``(B, (2 * d + 1) ** 2, H, W)`` where
        ``d = max_displacement`` and ``B``, ``H``, ``W`` are taken from
        ``x_1``; displacement ``(k, l)`` is stored at channel
        ``(l + d) + (2 * d + 1) * (k + d)``.
    """
    spatial_pad = ((0, 0), (0, 0), (pad_size, pad_size), (pad_size, pad_size))
    padded_1 = np.pad(x_1, spatial_pad, mode='constant')
    padded_2 = np.pad(x_2, spatial_pad, mode='constant')
    # Move axis 1 last so a spatial window slice also carries that axis.
    padded_1 = np.transpose(padded_1, (0, 2, 3, 1))
    padded_2 = np.transpose(padded_2, (0, 2, 3, 1))

    batch = int(padded_1.shape[0])
    height = int(x_1.shape[2])
    # Height and width are both read from x_1 so the output geometry comes
    # from a single input.
    width = int(x_1.shape[3])
    radius = max_displacement
    side = 2 * radius + 1

    result = np.zeros((batch, side * side, height, width), dtype=np.float32)
    for b, row, col in np.ndindex(batch, height, width):
        top_1 = row + pad_size
        left_1 = col + pad_size
        window_1 = padded_1[b, top_1:top_1 + kernel_size, left_1:left_1 +
                            kernel_size]
        for dy, dx in np.ndindex(side, side):
            top_2 = top_1 + dy - radius
            left_2 = left_1 + dx - radius
            window_2 = padded_2[b, top_2:top_2 + kernel_size, left_2:left_2 +
                                kernel_size]
            result[b, dx + side * dy, row, col] = np.mean(window_1 * window_2)
    return result
class TestCorrelationOp(unittest.TestCase):
    """Static-graph check of fluid.contrib.correlation against `corr`."""

    def test_check_output(self):
        # Report a skip instead of a silent pass when Paddle was built
        # without CUDA (the test runs on CUDAPlace(0)).
        if not fluid.core.is_compiled_with_cuda():
            self.skipTest("requires a CUDA build of Paddle")

        np.random.seed(13)
        np.set_printoptions(threshold=np.inf)
        # The declared placeholder shape matches the arrays that are fed.
        x_shape = (2, 3, 4, 5)
        x_type = 'float32'
        x1 = fluid.layers.data(
            name='x1',
            shape=x_shape,
            dtype=x_type,
            append_batch_size=False,
            stop_gradient=False)
        x2 = fluid.layers.data(
            name='x2',
            shape=x_shape,
            dtype=x_type,
            append_batch_size=False,
            stop_gradient=False)

        x1_np = np.random.randn(*x_shape).astype(x_type)
        x2_np = np.random.randn(*x_shape).astype(x_type)
        out_np = corr(
            x1_np,
            x2_np,
            pad_size=4,
            kernel_size=1,
            max_displacement=4,
            stride1=1,
            stride2=1)

        out = fluid.contrib.correlation(
            x1,
            x2,
            pad_size=4,
            kernel_size=1,
            max_displacement=4,
            stride1=1,
            stride2=1)

        # Appending an optimizer also builds the backward pass for the op.
        loss = fluid.layers.reduce_mean(out)
        optimizer = fluid.optimizer.Momentum(0.0001, 0.9)
        optimizer.minimize(loss)

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        res = exe.run(feed={'x1': x1_np,
                            'x2': x2_np},
                      fetch_list=[out.name, loss.name])

        self.assertTrue(np.allclose(res[0], out_np))
class Net(fluid.dygraph.Layer):
    """Dygraph layer that applies fluid.contrib.correlation to two inputs."""

    def __init__(self, name_scope):
        super(Net, self).__init__(name_scope)

    def forward(self, x1, x2):
        # Fixed op configuration used by the dygraph test.
        corr_kwargs = dict(
            pad_size=4,
            kernel_size=1,
            max_displacement=4,
            stride1=1,
            stride2=1)
        return fluid.contrib.correlation(x1, x2, **corr_kwargs)
class TestCorrelationOpDyGraph(unittest.TestCase):
    """Dygraph check of fluid.contrib.correlation (via `Net`) against `corr`."""

    def test_check_output(self):
        # Report a skip instead of a silent pass when Paddle was built
        # without CUDA (the test runs on CUDAPlace(0)).
        if not fluid.core.is_compiled_with_cuda():
            self.skipTest("requires a CUDA build of Paddle")

        np.random.seed(13)
        np.set_printoptions(threshold=np.inf)
        # Single source of truth for the shape of the generated inputs.
        x_shape = (2, 3, 4, 5)
        x_type = 'float32'
        place = fluid.CUDAPlace(0)
        with fluid.dygraph.guard(place):
            x1_np = np.random.randn(*x_shape).astype(x_type)
            x2_np = np.random.randn(*x_shape).astype(x_type)
            out_np = corr(
                x1_np,
                x2_np,
                pad_size=4,
                kernel_size=1,
                max_displacement=4,
                stride1=1,
                stride2=1)

            x1 = to_variable(x1_np)
            x2 = to_variable(x2_np)
            corr_pd = Net('corr_pd')
            y = corr_pd(x1, x2)
            out = y.numpy()
            self.assertTrue(np.allclose(out, out_np))
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -135,6 +135,11 @@ class FunctionSpec(object): ...@@ -135,6 +135,11 @@ class FunctionSpec(object):
input_with_spec = pack_sequence_as(args, input_with_spec) input_with_spec = pack_sequence_as(args, input_with_spec)
# If without specificing name in input_spec, add default name
# according to argument name from decorated function.
input_with_spec = replace_spec_empty_name(self._arg_names,
input_with_spec)
return input_with_spec return input_with_spec
@switch_to_static_graph @switch_to_static_graph
...@@ -309,3 +314,61 @@ def convert_to_input_spec(inputs, input_spec): ...@@ -309,3 +314,61 @@ def convert_to_input_spec(inputs, input_spec):
raise TypeError( raise TypeError(
"The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.". "The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.".
type_name(input_spec)) type_name(input_spec))
def replace_spec_empty_name(args_name, input_with_spec):
    """
    Fills in a default name for every InputSpec whose name was not specified,
    deriving it from the argument names of the decorated function.

    The naming rules are as follows:
        1. If InputSpec.name is not None, do nothing.
        2. If an argument `x` corresponds to a single InputSpec, use the
           argument name, e.g. `x`.
        3. If an argument `inputs` corresponds to a list(InputSpec), use names
           like `inputs_0`, `inputs_1`.
        4. If an argument `input_dic` corresponds to a dict(InputSpec), use
           each key as the name.

    For example:

        # case 1: foo(x, y)
        foo = to_static(foo, input_spec=[InputSpec([None, 10]), InputSpec([None])])
        print([in_var.name for in_var in foo.inputs])  # [x, y]

        # case 2: foo(inputs) where inputs is a list
        foo = to_static(foo, input_spec=[[InputSpec([None, 10]), InputSpec([None])]])
        print([in_var.name for in_var in foo.inputs])  # [inputs_0, inputs_1]

        # case 3: foo(inputs) where inputs is a dict
        foo = to_static(foo, input_spec=[{'x': InputSpec([None, 10]), 'y': InputSpec([None])}])
        print([in_var.name for in_var in foo.inputs])  # [x, y]

    Returns:
        list: a new list holding the processed specs, in the original order.
    """
    specs = list(input_with_spec)
    # Only the leading argument names that have a matching spec are used;
    # specs beyond the available argument names are left as they are.
    paired = zip(args_name[:len(specs)], specs)
    for position, (arg_name, spec) in enumerate(paired):
        specs[position] = _replace_spec_name(arg_name, spec)
    return specs
def _replace_spec_name(name, input_spec):
    """
    Assigns `name` to an InputSpec whose name is not specified, recursing
    into list/tuple/dict containers of specs.

    Returns the InputSpec itself (updated in place), a new list for a
    list/tuple input, a new dict for a dict input, or the input unchanged
    for any other type.
    """
    if isinstance(input_spec, paddle.static.InputSpec):
        # An explicitly given name always wins; only missing names are filled.
        if input_spec.name is None:
            input_spec.name = name
        return input_spec

    if isinstance(input_spec, (list, tuple)):
        # Elements are named `<name>_<index>`; the result is always a list,
        # even when the input is a tuple.
        return [
            _replace_spec_name("{}_{}".format(name, index), item)
            for index, item in enumerate(input_spec)
        ]

    if isinstance(input_spec, dict):
        # Dict entries are named after their own key, not after `name`.
        return {
            key: _replace_spec_name(key, item)
            for key, item in six.iteritems(input_spec)
        }

    # Any other type is passed through untouched.
    return input_spec
...@@ -37,7 +37,7 @@ from paddle.fluid.wrapped_decorator import wrap_decorator ...@@ -37,7 +37,7 @@ from paddle.fluid.wrapped_decorator import wrap_decorator
__all__ = [ __all__ = [
'TracedLayer', 'declarative', 'dygraph_to_static_func', 'set_code_level', 'TracedLayer', 'declarative', 'dygraph_to_static_func', 'set_code_level',
'set_verbosity' 'set_verbosity', 'save', 'load', 'SaveLoadConfig'
] ]
......
...@@ -217,7 +217,7 @@ def _dygraph_not_support_(func): ...@@ -217,7 +217,7 @@ def _dygraph_not_support_(func):
def _dygraph_only_(func): def _dygraph_only_(func):
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
assert in_dygraph_mode( assert in_dygraph_mode(
), "We Only support %s in imperative mode, please use fluid.dygraph.guard() as context to run it in imperative Mode" % func.__name__ ), "We Only support %s in dynamic mode, please call 'paddle.disable_static()' to enter dynamic mode." % func.__name__
return func(*args, **kwargs) return func(*args, **kwargs)
return __impl__ return __impl__
......
...@@ -12,9 +12,22 @@ ...@@ -12,9 +12,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
from functools import reduce
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from paddle.fluid import core from paddle.fluid import core
dtype_to_size = {
core.VarDesc.VarType.FP16: 2,
core.VarDesc.VarType.FP32: 4,
core.VarDesc.VarType.FP64: 8,
core.VarDesc.VarType.INT16: 2,
core.VarDesc.VarType.INT32: 4,
core.VarDesc.VarType.INT64: 8,
core.VarDesc.VarType.BOOL: 1,
core.VarDesc.VarType.UINT8: 1,
}
class VarBlock: class VarBlock:
def __init__(self, varname, offset, size): def __init__(self, varname, offset, size):
...@@ -51,11 +64,14 @@ class VarStruct(object): ...@@ -51,11 +64,14 @@ class VarStruct(object):
self.type = type self.type = type
self.lod_level = lod_level self.lod_level = lod_level
self.persistable = persistable self.persistable = persistable
# Size in bytes: number of elements times the per-element size.
# Passing 1 as the initializer keeps reduce() from raising TypeError on an
# empty shape and replaces the previous dead `self.m_size = 1` store.
self.m_size = reduce(lambda x, y: x * y, shape, 1)
self.m_size *= dtype_to_size[dtype]
def __str__(self): def __str__(self):
return "N: {}, S: {}, D: {}, T: {}, LL: {}, P: {}".format( return "N: {}, S: {}, D: {}, T: {}, LL: {}, P: {}, M: {}".format(
self.name, self.shape, self.dtype, self.type, self.lod_level, self.name, self.shape, self.dtype, self.type, self.lod_level,
self.persistable) self.persistable, self.m_size)
class VarDistributed(object): class VarDistributed(object):
......
...@@ -20,7 +20,8 @@ from __future__ import print_function ...@@ -20,7 +20,8 @@ from __future__ import print_function
from .layer_function_generator import generate_layer_fn from .layer_function_generator import generate_layer_fn
from .layer_function_generator import autodoc, templatedoc from .layer_function_generator import autodoc, templatedoc
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..framework import Variable from ..framework import Variable, in_dygraph_mode
from .. import core
from .loss import softmax_with_cross_entropy from .loss import softmax_with_cross_entropy
from . import tensor from . import tensor
from . import nn from . import nn
...@@ -2893,8 +2894,8 @@ def generate_proposals(scores, ...@@ -2893,8 +2894,8 @@ def generate_proposals(scores,
nms_thresh=0.5, nms_thresh=0.5,
min_size=0.1, min_size=0.1,
eta=1.0, eta=1.0,
name=None, return_rois_num=False,
return_rois_num=False): name=None):
""" """
:alias_main: paddle.nn.functional.generate_proposals :alias_main: paddle.nn.functional.generate_proposals
:alias: paddle.nn.functional.generate_proposals,paddle.nn.functional.vision.generate_proposals :alias: paddle.nn.functional.generate_proposals,paddle.nn.functional.vision.generate_proposals
...@@ -2949,6 +2950,10 @@ def generate_proposals(scores, ...@@ -2949,6 +2950,10 @@ def generate_proposals(scores,
num of each image in one batch. The N is the image's num. For example, the tensor has values [4,5] that represents num of each image in one batch. The N is the image's num. For example, the tensor has values [4,5] that represents
the first image has 4 Rois, the second image has 5 Rois. It only used in rcnn model. the first image has 4 Rois, the second image has 5 Rois. It only used in rcnn model.
'False' by default. 'False' by default.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns: Returns:
tuple: tuple:
A tuple with format ``(rpn_rois, rpn_roi_probs)``. A tuple with format ``(rpn_rois, rpn_roi_probs)``.
...@@ -2969,6 +2974,14 @@ def generate_proposals(scores, ...@@ -2969,6 +2974,14 @@ def generate_proposals(scores,
im_info, anchors, variances) im_info, anchors, variances)
""" """
if in_dygraph_mode():
assert return_rois_num, "return_rois_num should be True in dygraph mode."
attrs = ('pre_nms_topN', pre_nms_top_n, 'post_nms_topN', post_nms_top_n,
'nms_thresh', nms_thresh, 'min_size', min_size, 'eta', eta)
rpn_rois, rpn_roi_probs, rpn_rois_num = core.ops.generate_proposals(
scores, bbox_deltas, im_info, anchors, variances, *attrs)
return rpn_rois, rpn_roi_probs, rpn_rois_num
helper = LayerHelper('generate_proposals', **locals()) helper = LayerHelper('generate_proposals', **locals())
check_variable_and_dtype(scores, 'scores', ['float32'], check_variable_and_dtype(scores, 'scores', ['float32'],
...@@ -2986,7 +2999,14 @@ def generate_proposals(scores, ...@@ -2986,7 +2999,14 @@ def generate_proposals(scores,
dtype=bbox_deltas.dtype) dtype=bbox_deltas.dtype)
rpn_roi_probs = helper.create_variable_for_type_inference( rpn_roi_probs = helper.create_variable_for_type_inference(
dtype=scores.dtype) dtype=scores.dtype)
rpn_rois_lod = helper.create_variable_for_type_inference(dtype='int32') outputs = {
'RpnRois': rpn_rois,
'RpnRoiProbs': rpn_roi_probs,
}
if return_rois_num:
rpn_rois_num = helper.create_variable_for_type_inference(dtype='int32')
rpn_rois_num.stop_gradient = True
outputs['RpnRoisNum'] = rpn_rois_num
helper.append_op( helper.append_op(
type="generate_proposals", type="generate_proposals",
...@@ -3004,17 +3024,12 @@ def generate_proposals(scores, ...@@ -3004,17 +3024,12 @@ def generate_proposals(scores,
'min_size': min_size, 'min_size': min_size,
'eta': eta 'eta': eta
}, },
outputs={ outputs=outputs)
'RpnRois': rpn_rois,
'RpnRoiProbs': rpn_roi_probs,
'RpnRoisLod': rpn_rois_lod
})
rpn_rois.stop_gradient = True rpn_rois.stop_gradient = True
rpn_roi_probs.stop_gradient = True rpn_roi_probs.stop_gradient = True
rpn_rois_lod.stop_gradient = True
if return_rois_num: if return_rois_num:
return rpn_rois, rpn_roi_probs, rpn_rois_lod return rpn_rois, rpn_roi_probs, rpn_rois_num
else: else:
return rpn_rois, rpn_roi_probs return rpn_rois, rpn_roi_probs
...@@ -3656,6 +3671,7 @@ def distribute_fpn_proposals(fpn_rois, ...@@ -3656,6 +3671,7 @@ def distribute_fpn_proposals(fpn_rois,
max_level, max_level,
refer_level, refer_level,
refer_scale, refer_scale,
rois_num=None,
name=None): name=None):
""" """
:alias_main: paddle.nn.functional.distribute_fpn_proposals :alias_main: paddle.nn.functional.distribute_fpn_proposals
...@@ -3687,6 +3703,11 @@ def distribute_fpn_proposals(fpn_rois, ...@@ -3687,6 +3703,11 @@ def distribute_fpn_proposals(fpn_rois,
come from. come from.
refer_level(int32): The referring level of FPN layer with specified scale. refer_level(int32): The referring level of FPN layer with specified scale.
refer_scale(int32): The referring scale of FPN layer with specified level. refer_scale(int32): The referring scale of FPN layer with specified level.
rois_num(Tensor): 1-D Tensor contains the number of RoIs in each image.
The shape is [B] and data type is int32. B is the number of images.
If it is not None then return a list of 1-D Tensor. Each element
is the output RoIs' number of each image on the corresponding level
and the shape is [B]. None by default.
name(str, optional): For detailed information, please refer name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and to :ref:`api_guide_Name`. Usually name is no need to set and
None by default. None by default.
...@@ -3702,6 +3723,10 @@ def distribute_fpn_proposals(fpn_rois, ...@@ -3702,6 +3723,10 @@ def distribute_fpn_proposals(fpn_rois,
the number of total rois. The data type is int32. It is the number of total rois. The data type is int32. It is
used to restore the order of fpn_rois. used to restore the order of fpn_rois.
rois_num_per_level(List): A list of 1-D Tensor and each Tensor is
the RoIs' number in each image on the corresponding level. The shape
is [B] and data type of int32. B is the number of images
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -3716,26 +3741,52 @@ def distribute_fpn_proposals(fpn_rois, ...@@ -3716,26 +3741,52 @@ def distribute_fpn_proposals(fpn_rois,
refer_level=4, refer_level=4,
refer_scale=224) refer_scale=224)
""" """
num_lvl = max_level - min_level + 1
if in_dygraph_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
attrs = ('min_level', min_level, 'max_level', max_level, 'refer_level',
refer_level, 'refer_scale', refer_scale)
multi_rois, restore_ind, rois_num_per_level = core.ops.distribute_fpn_proposals(
fpn_rois, rois_num, num_lvl, num_lvl, *attrs)
return multi_rois, restore_ind, rois_num_per_level
check_variable_and_dtype(fpn_rois, 'fpn_rois', ['float32', 'float64'], check_variable_and_dtype(fpn_rois, 'fpn_rois', ['float32', 'float64'],
'distribute_fpn_proposals') 'distribute_fpn_proposals')
helper = LayerHelper('distribute_fpn_proposals', **locals()) helper = LayerHelper('distribute_fpn_proposals', **locals())
dtype = helper.input_dtype('fpn_rois') dtype = helper.input_dtype('fpn_rois')
num_lvl = max_level - min_level + 1
multi_rois = [ multi_rois = [
helper.create_variable_for_type_inference(dtype) for i in range(num_lvl) helper.create_variable_for_type_inference(dtype) for i in range(num_lvl)
] ]
restore_ind = helper.create_variable_for_type_inference(dtype='int32') restore_ind = helper.create_variable_for_type_inference(dtype='int32')
inputs = {'FpnRois': fpn_rois}
outputs = {
'MultiFpnRois': multi_rois,
'RestoreIndex': restore_ind,
}
if rois_num is not None:
inputs['RoisNum'] = rois_num
rois_num_per_level = [
helper.create_variable_for_type_inference(dtype='int32')
for i in range(num_lvl)
]
outputs['MultiLevelRoIsNum'] = rois_num_per_level
helper.append_op( helper.append_op(
type='distribute_fpn_proposals', type='distribute_fpn_proposals',
inputs={'FpnRois': fpn_rois}, inputs=inputs,
outputs={'MultiFpnRois': multi_rois, outputs=outputs,
'RestoreIndex': restore_ind},
attrs={ attrs={
'min_level': min_level, 'min_level': min_level,
'max_level': max_level, 'max_level': max_level,
'refer_level': refer_level, 'refer_level': refer_level,
'refer_scale': refer_scale 'refer_scale': refer_scale
}) })
if rois_num is not None:
return multi_rois, restore_ind, rois_num_per_level
return multi_rois, restore_ind return multi_rois, restore_ind
...@@ -3820,6 +3871,7 @@ def collect_fpn_proposals(multi_rois, ...@@ -3820,6 +3871,7 @@ def collect_fpn_proposals(multi_rois,
min_level, min_level,
max_level, max_level,
post_nms_top_n, post_nms_top_n,
rois_num_per_level=None,
name=None): name=None):
""" """
:alias_main: paddle.nn.functional.collect_fpn_proposals :alias_main: paddle.nn.functional.collect_fpn_proposals
...@@ -3846,6 +3898,12 @@ def collect_fpn_proposals(multi_rois, ...@@ -3846,6 +3898,12 @@ def collect_fpn_proposals(multi_rois,
min_level(int): The lowest level of FPN layer to collect min_level(int): The lowest level of FPN layer to collect
max_level(int): The highest level of FPN layer to collect max_level(int): The highest level of FPN layer to collect
post_nms_top_n(int): The number of selected RoIs post_nms_top_n(int): The number of selected RoIs
rois_num_per_level(list, optional): The List of RoIs' numbers.
Each element is 1-D Tensor which contains the RoIs' number of each
image on each level and the shape is [B] and data type is
int32, B is the number of images. If it is not None then return
a 1-D Tensor contains the output RoIs' number of each image and
the shape is [B]. Default: None
name(str, optional): For detailed information, please refer name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and to :ref:`api_guide_Name`. Usually name is no need to set and
None by default. None by default.
...@@ -3856,6 +3914,9 @@ def collect_fpn_proposals(multi_rois, ...@@ -3856,6 +3914,9 @@ def collect_fpn_proposals(multi_rois,
fpn_rois(Variable): 2-D LoDTensor with shape [N, 4] and data type is fpn_rois(Variable): 2-D LoDTensor with shape [N, 4] and data type is
float32 or float64. Selected RoIs. float32 or float64. Selected RoIs.
rois_num(Tensor): 1-D Tensor contains the RoIs's number of each
image. The shape is [B] and data type is int32. B is the number of
images.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -3879,21 +3940,38 @@ def collect_fpn_proposals(multi_rois, ...@@ -3879,21 +3940,38 @@ def collect_fpn_proposals(multi_rois,
""" """
check_type(multi_rois, 'multi_rois', list, 'collect_fpn_proposals') check_type(multi_rois, 'multi_rois', list, 'collect_fpn_proposals')
check_type(multi_scores, 'multi_scores', list, 'collect_fpn_proposals') check_type(multi_scores, 'multi_scores', list, 'collect_fpn_proposals')
num_lvl = max_level - min_level + 1
input_rois = multi_rois[:num_lvl]
input_scores = multi_scores[:num_lvl]
if in_dygraph_mode():
assert rois_num_per_level is not None, "rois_num_per_level should not be None in dygraph mode."
attrs = ('post_nms_topN', post_nms_top_n)
output_rois, rois_num = core.ops.collect_fpn_proposals(
input_rois, input_scores, rois_num_per_level, *attrs)
helper = LayerHelper('collect_fpn_proposals', **locals()) helper = LayerHelper('collect_fpn_proposals', **locals())
dtype = helper.input_dtype('multi_rois') dtype = helper.input_dtype('multi_rois')
check_dtype(dtype, 'multi_rois', ['float32', 'float64'], check_dtype(dtype, 'multi_rois', ['float32', 'float64'],
'collect_fpn_proposals') 'collect_fpn_proposals')
num_lvl = max_level - min_level + 1
input_rois = multi_rois[:num_lvl]
input_scores = multi_scores[:num_lvl]
output_rois = helper.create_variable_for_type_inference(dtype) output_rois = helper.create_variable_for_type_inference(dtype)
output_rois.stop_gradient = True output_rois.stop_gradient = True
inputs = {
'MultiLevelRois': input_rois,
'MultiLevelScores': input_scores,
}
outputs = {'FpnRois': output_rois}
if rois_num_per_level is not None:
inputs['MultiLevelRoIsNum'] = rois_num_per_level
rois_num = helper.create_variable_for_type_inference(dtype='int32')
rois_num.stop_gradient = True
outputs['RoisNum'] = rois_num
helper.append_op( helper.append_op(
type='collect_fpn_proposals', type='collect_fpn_proposals',
inputs={ inputs=inputs,
'MultiLevelRois': input_rois, outputs=outputs,
'MultiLevelScores': input_scores
},
outputs={'FpnRois': output_rois},
attrs={'post_nms_topN': post_nms_top_n}) attrs={'post_nms_topN': post_nms_top_n})
if rois_num_per_level is not None:
return output_rois, rois_num
return output_rois return output_rois
...@@ -6862,7 +6862,8 @@ def roi_pool(input, ...@@ -6862,7 +6862,8 @@ def roi_pool(input,
pooled_height=1, pooled_height=1,
pooled_width=1, pooled_width=1,
spatial_scale=1.0, spatial_scale=1.0,
rois_lod=None): rois_num=None,
name=None):
""" """
:alias_main: paddle.nn.functional.roi_pool :alias_main: paddle.nn.functional.roi_pool
:alias: paddle.nn.functional.roi_pool,paddle.nn.functional.vision.roi_pool :alias: paddle.nn.functional.roi_pool,paddle.nn.functional.vision.roi_pool
...@@ -6882,10 +6883,14 @@ def roi_pool(input, ...@@ -6882,10 +6883,14 @@ def roi_pool(input,
Args: Args:
input (Variable): Input feature, 4D-Tensor with the shape of [N,C,H,W], where N is the batch size, C is the input channel, H is Height, W is weight. The data type is float32 or float64. input (Variable): Input feature, 4D-Tensor with the shape of [N,C,H,W], where N is the batch size, C is the input channel, H is Height, W is weight. The data type is float32 or float64.
rois (Variable): ROIs (Regions of Interest) to pool over. 2D-LoDTensor with the shape of [num_rois,4], the lod level is 1. Given as [[x1, y1, x2, y2], ...], (x1, y1) is the top left coordinates, and (x2, y2) is the bottom right coordinates. rois (Variable): ROIs (Regions of Interest) to pool over. 2D-LoDTensor with the shape of [num_rois,4], the lod level is 1. Given as [[x1, y1, x2, y2], ...], (x1, y1) is the top left coordinates, and (x2, y2) is the bottom right coordinates.
rois_lod (Variable): The lod info of rois. Default: None
pooled_height (int, optional): The pooled output height, data type is int32. Default: 1 pooled_height (int, optional): The pooled output height, data type is int32. Default: 1
pooled_width (int, optional): The pooled output height, data type is int32. Default: 1 pooled_width (int, optional): The pooled output height, data type is int32. Default: 1
spatial_scale (float, optional): Multiplicative spatial scale factor to translate ROI coords from their input scale to the scale used when pooling. Default: 1.0 spatial_scale (float, optional): Multiplicative spatial scale factor to translate ROI coords from their input scale to the scale used when pooling. Default: 1.0
rois_num (Tensor): The number of RoIs in each image. Default: None
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns: Returns:
Variable: The pooled feature, 4D-Tensor with the shape of [num_rois, C, pooled_height, pooled_width]. Variable: The pooled feature, 4D-Tensor with the shape of [num_rois, C, pooled_height, pooled_width].
...@@ -6905,11 +6910,11 @@ def roi_pool(input, ...@@ -6905,11 +6910,11 @@ def roi_pool(input,
input_data = np.array([i for i in range(1,17)]).reshape(1,1,4,4).astype(DATATYPE) input_data = np.array([i for i in range(1,17)]).reshape(1,1,4,4).astype(DATATYPE)
roi_data =fluid.create_lod_tensor(np.array([[1., 1., 2., 2.], [1.5, 1.5, 3., 3.]]).astype(DATATYPE),[[2]], place) roi_data =fluid.create_lod_tensor(np.array([[1., 1., 2., 2.], [1.5, 1.5, 3., 3.]]).astype(DATATYPE),[[2]], place)
rois_lod_data = np.array([0, 2]) rois_num_data = np.array([2]).astype('int32')
x = fluid.data(name='input', shape=[None,1,4,4], dtype=DATATYPE) x = fluid.data(name='input', shape=[None,1,4,4], dtype=DATATYPE)
rois = fluid.data(name='roi', shape=[None,4], dtype=DATATYPE) rois = fluid.data(name='roi', shape=[None,4], dtype=DATATYPE)
rois_lod = fluid.data(name='rois_lod', shape=[None], dtype='int64') rois_num = fluid.data(name='rois_num', shape=[None], dtype='int32')
pool_out = fluid.layers.roi_pool( pool_out = fluid.layers.roi_pool(
input=x, input=x,
...@@ -6917,24 +6922,36 @@ def roi_pool(input, ...@@ -6917,24 +6922,36 @@ def roi_pool(input,
pooled_height=1, pooled_height=1,
pooled_width=1, pooled_width=1,
spatial_scale=1.0, spatial_scale=1.0,
rois_lod=rois_lod) rois_num=rois_num)
exe = fluid.Executor(place) exe = fluid.Executor(place)
out, = exe.run(feed={'input':input_data ,'roi':roi_data, 'rois_lod': rois_lod_data}, fetch_list=[pool_out.name]) out, = exe.run(feed={'input':input_data ,'roi':roi_data, 'rois_num': rois_num_data}, fetch_list=[pool_out.name])
print(out) #array([[[[11.]]], [[[16.]]]], dtype=float32) print(out) #array([[[[11.]]], [[[16.]]]], dtype=float32)
print(np.array(out).shape) # (2, 1, 1, 1) print(np.array(out).shape) # (2, 1, 1, 1)
""" """
if in_dygraph_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
pool_out, argmaxes = core.ops.roi_pool(
input, rois, rois_num, "pooled_height", pooled_height,
"pooled_width", pooled_width, "spatial_scale", spatial_scale)
return pool_out, argmaxes
check_variable_and_dtype(input, 'input', ['float32'], 'roi_pool') check_variable_and_dtype(input, 'input', ['float32'], 'roi_pool')
check_variable_and_dtype(rois, 'rois', ['float32'], 'roi_pool') check_variable_and_dtype(rois, 'rois', ['float32'], 'roi_pool')
helper = LayerHelper('roi_pool', **locals()) helper = LayerHelper('roi_pool', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
argmaxes = helper.create_variable_for_type_inference(dtype='int32') argmaxes = helper.create_variable_for_type_inference(dtype='int32')
inputs = {
"X": input,
"ROIs": rois,
}
if rois_num is not None:
inputs['RoisNum'] = rois_num
helper.append_op( helper.append_op(
type="roi_pool", type="roi_pool",
inputs={"X": input, inputs=inputs,
"ROIs": rois,
"RoisLod": rois_lod},
outputs={"Out": pool_out, outputs={"Out": pool_out,
"Argmax": argmaxes}, "Argmax": argmaxes},
attrs={ attrs={
...@@ -6952,8 +6969,8 @@ def roi_align(input, ...@@ -6952,8 +6969,8 @@ def roi_align(input,
pooled_width=1, pooled_width=1,
spatial_scale=1.0, spatial_scale=1.0,
sampling_ratio=-1, sampling_ratio=-1,
name=None, rois_num=None,
rois_lod=None): name=None):
""" """
:alias_main: paddle.nn.functional.roi_align :alias_main: paddle.nn.functional.roi_align
:alias: paddle.nn.functional.roi_align,paddle.nn.functional.vision.roi_align :alias: paddle.nn.functional.roi_align,paddle.nn.functional.vision.roi_align
...@@ -6968,11 +6985,11 @@ def roi_align(input, ...@@ -6968,11 +6985,11 @@ def roi_align(input,
data type is float32 or float64. Given as [[x1, y1, x2, y2], ...], data type is float32 or float64. Given as [[x1, y1, x2, y2], ...],
(x1, y1) is the top left coordinates, and (x2, y2) is the bottom (x1, y1) is the top left coordinates, and (x2, y2) is the bottom
right coordinates. right coordinates.
rois_lod (Variable): The lod info of rois. Default: None
pooled_height (int32, optional): ${pooled_height_comment} Default: 1 pooled_height (int32, optional): ${pooled_height_comment} Default: 1
pooled_width (int32, optional): ${pooled_width_comment} Default: 1 pooled_width (int32, optional): ${pooled_width_comment} Default: 1
spatial_scale (float32, optional): ${spatial_scale_comment} Default: 1.0 spatial_scale (float32, optional): ${spatial_scale_comment} Default: 1.0
sampling_ratio(int32, optional): ${sampling_ratio_comment} Default: -1 sampling_ratio(int32, optional): ${sampling_ratio_comment} Default: -1
rois_num (Tensor): The number of RoIs in each image. Default: None
name(str, optional): For detailed information, please refer name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and to :ref:`api_guide_Name`. Usually name is no need to set and
None by default. None by default.
...@@ -6991,26 +7008,38 @@ def roi_align(input, ...@@ -6991,26 +7008,38 @@ def roi_align(input,
name='data', shape=[None, 256, 32, 32], dtype='float32') name='data', shape=[None, 256, 32, 32], dtype='float32')
rois = fluid.data( rois = fluid.data(
name='rois', shape=[None, 4], dtype='float32') name='rois', shape=[None, 4], dtype='float32')
rois_lod = fluid.data(name='rois_lod', shape=[None], dtype='int64') rois_num = fluid.data(name='rois_num', shape=[None], dtype='int32')
align_out = fluid.layers.roi_align(input=x, align_out = fluid.layers.roi_align(input=x,
rois=rois, rois=rois,
pooled_height=7, pooled_height=7,
pooled_width=7, pooled_width=7,
spatial_scale=0.5, spatial_scale=0.5,
sampling_ratio=-1, sampling_ratio=-1,
rois_lod=rois_lod) rois_num=rois_num)
""" """
if in_dygraph_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
align_out = core.ops.roi_align(
input, rois, rois_num, "pooled_height", pooled_height,
"pooled_width", pooled_width, "spatial_scale", spatial_scale,
"sampling_ratio", sampling_ratio)
return align_out
check_variable_and_dtype(input, 'input', ['float32', 'float64'], check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'roi_align') 'roi_align')
check_variable_and_dtype(rois, 'rois', ['float32', 'float64'], 'roi_align') check_variable_and_dtype(rois, 'rois', ['float32', 'float64'], 'roi_align')
helper = LayerHelper('roi_align', **locals()) helper = LayerHelper('roi_align', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
align_out = helper.create_variable_for_type_inference(dtype) align_out = helper.create_variable_for_type_inference(dtype)
inputs = {
"X": input,
"ROIs": rois,
}
if rois_num is not None:
inputs['RoisNum'] = rois_num
helper.append_op( helper.append_op(
type="roi_align", type="roi_align",
inputs={"X": input, inputs=inputs,
"ROIs": rois,
"RoisLod": rois_lod},
outputs={"Out": align_out}, outputs={"Out": align_out},
attrs={ attrs={
"pooled_height": pooled_height, "pooled_height": pooled_height,
...@@ -10850,8 +10879,7 @@ def slice(input, axes, starts, ends): ...@@ -10850,8 +10879,7 @@ def slice(input, axes, starts, ends):
result = [ [2, 3, 4], ] # result = data[0:1, 1:4] result = [ [2, 3, 4], ] # result = data[0:1, 1:4]
Args: Args:
input (Variable): A ``Tensor`` or ``LoDTensor`` . The data type is ``float16``, ``float32``, ``float64``, ``int32`` or ``int64``. input (Variable): A ``Tensor`` or ``LoDTensor`` . The data type is ``float16``, ``float32``, ``float64``, ``int32`` or ``int64``.
axes (list|tuple): The data type is ``int32`` . Axes that `starts` and `ends` apply to. axes (list|tuple): The data type is ``int32`` . Axes that `starts` and `ends` apply to .
It's optional. If it is not provides, it will be treated as :math:`[0,1,...,len(starts)-1]`.
starts (list|tuple|Variable): The data type is ``int32`` . If ``starts`` is a list or tuple, the elements of starts (list|tuple|Variable): The data type is ``int32`` . If ``starts`` is a list or tuple, the elements of
it should be integers or Tensors with shape [1]. If ``starts`` is an Variable, it should be an 1-D Tensor. it should be integers or Tensors with shape [1]. If ``starts`` is an Variable, it should be an 1-D Tensor.
It represents starting indices of corresponding axis in ``axes``. It represents starting indices of corresponding axis in ``axes``.
......
...@@ -36,7 +36,7 @@ __all__ = [ ...@@ -36,7 +36,7 @@ __all__ = [
'tensor_array_to_tensor', 'concat', 'sums', 'assign', 'tensor_array_to_tensor', 'concat', 'sums', 'assign',
'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax', 'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax',
'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite', 'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite',
'range', 'linspace', 'zeros_like', 'ones_like', 'diag', 'eye' 'range', 'linspace', 'zeros_like', 'ones_like', 'diag', 'eye', 'triu'
] ]
...@@ -1725,3 +1725,9 @@ def ones_like(x, out=None): ...@@ -1725,3 +1725,9 @@ def ones_like(x, out=None):
attrs={'value': 1.0}, attrs={'value': 1.0},
outputs={'Out': [out]}) outputs={'Out': [out]})
return out return out
@deprecated(since="2.0.0", update_to="paddle.triu")
def triu(input, diagonal=0, name=None):
    """
    Deprecated alias that forwards to ``paddle.tensor.triu``.

    Args:
        input: Forwarded as the ``x`` argument.
        diagonal (int, optional): Forwarded unchanged. Default: 0.
        name (str, optional): Forwarded unchanged. Default: None.

    Returns:
        Whatever ``paddle.tensor.triu`` returns.
    """
    # Imported at call time rather than at module load time
    # (NOTE(review): presumably to avoid a circular import -- confirm).
    import paddle

    forwarded = dict(x=input, diagonal=diagonal, name=name)
    return paddle.tensor.triu(**forwarded)
...@@ -19,6 +19,57 @@ import paddle.fluid.layers as layers ...@@ -19,6 +19,57 @@ import paddle.fluid.layers as layers
from paddle.fluid.layers import detection from paddle.fluid.layers import detection
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
import unittest import unittest
import contextlib
import numpy as np
from unittests.test_imperative_base import new_program_scope
from paddle.fluid.dygraph import base
from paddle.fluid import core
class LayerTest(unittest.TestCase):
    """Base test case with helpers to run code in static-graph and dynamic-graph mode."""

    @classmethod
    def setUpClass(cls):
        # Seed assigned to the default startup/main programs by the helpers below.
        cls.seed = 111

    @classmethod
    def tearDownClass(cls):
        pass

    def _get_place(self, force_to_use_cpu=False):
        """Return CUDAPlace(0) on CUDA builds unless CPU is forced, else CPUPlace."""
        # force_to_use_cpu is meant for ops that only have a CPU kernel.
        use_cuda = (not force_to_use_cpu) and core.is_compiled_with_cuda()
        return core.CUDAPlace(0) if use_cuda else core.CPUPlace()

    def _seed_default_programs(self):
        """Assign ``self.seed`` to the default startup and main programs."""
        for get_program in (fluid.default_startup_program,
                            fluid.default_main_program):
            get_program().random_seed = self.seed

    @contextlib.contextmanager
    def static_graph(self):
        """Context manager: enter ``new_program_scope()`` with seeded default programs."""
        with new_program_scope():
            self._seed_default_programs()
            yield

    def get_static_graph_result(self,
                                feed,
                                fetch_list,
                                with_lod=False,
                                force_to_use_cpu=False):
        """Run the default startup program, then the default main program.

        Returns the result of running the main program with ``feed`` and
        ``fetch_list``; ``return_numpy`` is disabled when ``with_lod`` is True.
        """
        executor = fluid.Executor(self._get_place(force_to_use_cpu))
        executor.run(fluid.default_startup_program())
        fetched = executor.run(fluid.default_main_program(),
                               feed=feed,
                               fetch_list=fetch_list,
                               return_numpy=(not with_lod))
        return fetched

    @contextlib.contextmanager
    def dynamic_graph(self, force_to_use_cpu=False):
        """Context manager: enter ``fluid.dygraph.guard`` on the selected place."""
        place = self._get_place(force_to_use_cpu=force_to_use_cpu)
        with fluid.dygraph.guard(place):
            self._seed_default_programs()
            yield
class TestDetection(unittest.TestCase): class TestDetection(unittest.TestCase):
...@@ -481,45 +532,67 @@ class TestRpnTargetAssign(unittest.TestCase): ...@@ -481,45 +532,67 @@ class TestRpnTargetAssign(unittest.TestCase):
print(str(program)) print(str(program))
class TestGenerateProposals(unittest.TestCase): class TestGenerateProposals(LayerTest):
def test_generate_proposals(self): def test_generate_proposals(self):
program = Program() scores_np = np.random.rand(2, 3, 4, 4).astype('float32')
with program_guard(program): bbox_deltas_np = np.random.rand(2, 12, 4, 4).astype('float32')
data_shape = [20, 64, 64] im_info_np = np.array([[8, 8, 0.5], [6, 6, 0.5]]).astype('float32')
images = fluid.layers.data( anchors_np = np.reshape(np.arange(4 * 4 * 3 * 4),
name='images', shape=data_shape, dtype='float32') [4, 4, 3, 4]).astype('float32')
im_info = fluid.layers.data( variances_np = np.ones((4, 4, 3, 4)).astype('float32')
name='im_info', shape=[3], dtype='float32')
anchors, variances = fluid.layers.anchor_generator( with self.static_graph():
name='anchor_generator', scores = fluid.data(
input=images, name='scores', shape=[2, 3, 4, 4], dtype='float32')
anchor_sizes=[32, 64], bbox_deltas = fluid.data(
aspect_ratios=[1.0], name='bbox_deltas', shape=[2, 12, 4, 4], dtype='float32')
variance=[0.1, 0.1, 0.2, 0.2], im_info = fluid.data(name='im_info', shape=[2, 3], dtype='float32')
stride=[16.0, 16.0], anchors = fluid.data(
offset=0.5) name='anchors', shape=[4, 4, 3, 4], dtype='float32')
num_anchors = anchors.shape[2] variances = fluid.data(
scores = fluid.layers.data( name='var', shape=[4, 4, 3, 4], dtype='float32')
name='scores', shape=[num_anchors, 8, 8], dtype='float32') rois, roi_probs, rois_num = fluid.layers.generate_proposals(
bbox_deltas = fluid.layers.data( scores,
name='bbox_deltas', bbox_deltas,
shape=[num_anchors * 4, 8, 8], im_info,
dtype='float32') anchors,
rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals( variances,
name='generate_proposals', pre_nms_top_n=10,
scores=scores, post_nms_top_n=5,
bbox_deltas=bbox_deltas, return_rois_num=True)
im_info=im_info, rois_stat, roi_probs_stat, rois_num_stat = self.get_static_graph_result(
anchors=anchors, feed={
variances=variances, 'scores': scores_np,
pre_nms_top_n=6000, 'bbox_deltas': bbox_deltas_np,
post_nms_top_n=1000, 'im_info': im_info_np,
nms_thresh=0.5, 'anchors': anchors_np,
min_size=0.1, 'var': variances_np
eta=1.0) },
self.assertIsNotNone(rpn_rois) fetch_list=[rois, roi_probs, rois_num],
self.assertIsNotNone(rpn_roi_probs) with_lod=True)
print(rpn_rois.shape)
with self.dynamic_graph():
scores_dy = base.to_variable(scores_np)
bbox_deltas_dy = base.to_variable(bbox_deltas_np)
im_info_dy = base.to_variable(im_info_np)
anchors_dy = base.to_variable(anchors_np)
variances_dy = base.to_variable(variances_np)
rois, roi_probs, rois_num = fluid.layers.generate_proposals(
scores_dy,
bbox_deltas_dy,
im_info_dy,
anchors_dy,
variances_dy,
pre_nms_top_n=10,
post_nms_top_n=5,
return_rois_num=True)
rois_dy = rois.numpy()
roi_probs_dy = roi_probs.numpy()
rois_num_dy = rois_num.numpy()
self.assertTrue(np.array_equal(np.array(rois_stat), rois_dy))
self.assertTrue(np.array_equal(np.array(roi_probs_stat), roi_probs_dy))
self.assertTrue(np.array_equal(np.array(rois_num_stat), rois_num_dy))
class TestYoloDetection(unittest.TestCase): class TestYoloDetection(unittest.TestCase):
...@@ -648,30 +721,81 @@ class TestMulticlassNMS2(unittest.TestCase): ...@@ -648,30 +721,81 @@ class TestMulticlassNMS2(unittest.TestCase):
self.assertIsNotNone(index) self.assertIsNotNone(index)
class TestCollectFpnPropsals(unittest.TestCase): class TestCollectFpnPropsals(LayerTest):
def test_collect_fpn_proposals(self): def test_collect_fpn_proposals(self):
program = Program() multi_bboxes_np = []
with program_guard(program): multi_scores_np = []
rois_num_per_level_np = []
for i in range(4):
bboxes_np = np.random.rand(5, 4).astype('float32')
scores_np = np.random.rand(5, 1).astype('float32')
rois_num = np.array([2, 3]).astype('int32')
multi_bboxes_np.append(bboxes_np)
multi_scores_np.append(scores_np)
rois_num_per_level_np.append(rois_num)
with self.static_graph():
multi_bboxes = [] multi_bboxes = []
multi_scores = [] multi_scores = []
rois_num_per_level = []
for i in range(4): for i in range(4):
bboxes = layers.data( bboxes = fluid.data(
name='rois' + str(i), name='rois' + str(i),
shape=[10, 4], shape=[5, 4],
dtype='float32', dtype='float32',
lod_level=1, lod_level=1)
append_batch_size=False) scores = fluid.data(
scores = layers.data(
name='scores' + str(i), name='scores' + str(i),
shape=[10, 1], shape=[5, 1],
dtype='float32', dtype='float32',
lod_level=1, lod_level=1)
append_batch_size=False) rois_num = fluid.data(
name='rois_num' + str(i), shape=[None], dtype='int32')
multi_bboxes.append(bboxes) multi_bboxes.append(bboxes)
multi_scores.append(scores) multi_scores.append(scores)
fpn_rois = layers.collect_fpn_proposals(multi_bboxes, multi_scores, rois_num_per_level.append(rois_num)
2, 5, 10)
self.assertIsNotNone(fpn_rois) fpn_rois, rois_num = layers.collect_fpn_proposals(
multi_bboxes,
multi_scores,
2,
5,
10,
rois_num_per_level=rois_num_per_level)
feed = {}
for i in range(4):
feed['rois' + str(i)] = multi_bboxes_np[i]
feed['scores' + str(i)] = multi_scores_np[i]
feed['rois_num' + str(i)] = rois_num_per_level_np[i]
fpn_rois_stat, rois_num_stat = self.get_static_graph_result(
feed=feed, fetch_list=[fpn_rois, rois_num], with_lod=True)
fpn_rois_stat = np.array(fpn_rois_stat)
rois_num_stat = np.array(rois_num_stat)
with self.dynamic_graph():
multi_bboxes_dy = []
multi_scores_dy = []
rois_num_per_level_dy = []
for i in range(4):
bboxes_dy = base.to_variable(multi_bboxes_np[i])
scores_dy = base.to_variable(multi_scores_np[i])
rois_num_dy = base.to_variable(rois_num_per_level_np[i])
multi_bboxes_dy.append(bboxes_dy)
multi_scores_dy.append(scores_dy)
rois_num_per_level_dy.append(rois_num_dy)
fpn_rois_dy, rois_num_dy = fluid.layers.collect_fpn_proposals(
multi_bboxes_dy,
multi_scores_dy,
2,
5,
10,
rois_num_per_level=rois_num_per_level_dy)
fpn_rois_dy = fpn_rois_dy.numpy()
rois_num_dy = rois_num_dy.numpy()
self.assertTrue(np.array_equal(fpn_rois_stat, fpn_rois_dy))
self.assertTrue(np.array_equal(rois_num_stat, rois_num_dy))
def test_collect_fpn_proposals_error(self): def test_collect_fpn_proposals_error(self):
def generate_input(bbox_type, score_type, name): def generate_input(bbox_type, score_type, name):
...@@ -717,20 +841,51 @@ class TestCollectFpnPropsals(unittest.TestCase): ...@@ -717,20 +841,51 @@ class TestCollectFpnPropsals(unittest.TestCase):
post_nms_top_n=2000) post_nms_top_n=2000)
class TestDistributeFpnProposals(unittest.TestCase): class TestDistributeFpnProposals(LayerTest):
def test_distribute_fpn_proposals(self): def test_distribute_fpn_proposals(self):
program = Program() rois_np = np.random.rand(10, 4).astype('float32')
with program_guard(program): rois_num_np = np.array([4, 6]).astype('int32')
fpn_rois = fluid.layers.data( with self.static_graph():
name='data', shape=[4], dtype='float32', lod_level=1) rois = fluid.data(name='rois', shape=[10, 4], dtype='float32')
multi_rois, restore_ind = layers.distribute_fpn_proposals( rois_num = fluid.data(name='rois_num', shape=[None], dtype='int32')
fpn_rois=fpn_rois, multi_rois, restore_ind, rois_num_per_level = layers.distribute_fpn_proposals(
fpn_rois=rois,
min_level=2, min_level=2,
max_level=5, max_level=5,
refer_level=4, refer_level=4,
refer_scale=224) refer_scale=224,
self.assertIsNotNone(multi_rois) rois_num=rois_num)
self.assertIsNotNone(restore_ind) fetch_list = multi_rois + [restore_ind] + rois_num_per_level
output_stat = self.get_static_graph_result(
feed={'rois': rois_np,
'rois_num': rois_num_np},
fetch_list=fetch_list,
with_lod=True)
output_stat_np = []
for output in output_stat:
output_np = np.array(output)
if len(output_np) > 0:
output_stat_np.append(output_np)
with self.dynamic_graph():
rois_dy = base.to_variable(rois_np)
rois_num_dy = base.to_variable(rois_num_np)
multi_rois_dy, restore_ind_dy, rois_num_per_level_dy = layers.distribute_fpn_proposals(
fpn_rois=rois_dy,
min_level=2,
max_level=5,
refer_level=4,
refer_scale=224,
rois_num=rois_num_dy)
output_dy = multi_rois_dy + [restore_ind_dy] + rois_num_per_level_dy
output_dy_np = []
for output in output_dy:
output_np = output.numpy()
if len(output_np) > 0:
output_dy_np.append(output_np)
for res_stat, res_dy in zip(output_stat_np, output_dy_np):
self.assertTrue(np.array_equal(res_stat, res_dy))
def test_distribute_fpn_proposals_error(self): def test_distribute_fpn_proposals_error(self):
program = Program() program = Program()
......
...@@ -440,6 +440,8 @@ if(WITH_DISTRIBUTE) ...@@ -440,6 +440,8 @@ if(WITH_DISTRIBUTE)
# FIXME(seiriosX) will fix this # FIXME(seiriosX) will fix this
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_sparse_embedding_ctr") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_sparse_embedding_ctr")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_gloo") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_gloo")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_a_sync_optimizer_auto")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_ctr")
py_test_modules(test_recv_save_op MODULES test_recv_save_op ENVS ${dist_ENVS}) py_test_modules(test_recv_save_op MODULES test_recv_save_op ENVS ${dist_ENVS})
py_test_modules(test_transpiler_ops MODULES test_transpiler_ops ENVS ${dist_ENVS}) py_test_modules(test_transpiler_ops MODULES test_transpiler_ops ENVS ${dist_ENVS})
......
...@@ -47,8 +47,8 @@ class SimpleNet(Layer): ...@@ -47,8 +47,8 @@ class SimpleNet(Layer):
return z return z
@declarative(input_spec=[[InputSpec([None, 10]), InputSpec([None, 10])]]) @declarative(input_spec=[[InputSpec([None, 10]), InputSpec([None, 10])]])
def func_with_list(self, l): def func_with_list(self, l, int_val=1):
x, y, int_val = l x, y = l
z = x + y z = x + y
z = z + int_val z = z + int_val
return z return z
...@@ -60,10 +60,7 @@ class SimpleNet(Layer): ...@@ -60,10 +60,7 @@ class SimpleNet(Layer):
def func_with_dict(self, d): def func_with_dict(self, d):
x = d['x'] x = d['x']
y = d['y'] y = d['y']
int_val = d['int_val']
z = x + y z = x + y
z = z + int_val
return z return z
...@@ -131,10 +128,10 @@ class TestInputSpec(unittest.TestCase): ...@@ -131,10 +128,10 @@ class TestInputSpec(unittest.TestCase):
self.assertTrue(len(net.add_func.program_cache) == 1) self.assertTrue(len(net.add_func.program_cache) == 1)
# 5. test input with list # 5. test input with list
out = net.func_with_list([x, y, int_val]) out = net.func_with_list([x, y], int_val)
# 6. test input with dict # 6. test input with dict
out = net.func_with_dict({'x': x, 'y': y, 'int_val': int_val}) out = net.func_with_dict({'x': x, 'y': y})
# 7. test input with list contains dict # 7. test input with list contains dict
int_np = np.ones([1]).astype('float32') int_np = np.ones([1]).astype('float32')
...@@ -293,6 +290,30 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): ...@@ -293,6 +290,30 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
foo_3.concrete_program foo_3.concrete_program
class TestInputDefaultName(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.net = SimpleNet()
def assert_default_name(self, func_name, input_names):
decorated_func = getattr(self.net, func_name)
spec_names = [x.name for x in decorated_func.inputs]
self.assertListEqual(spec_names, input_names)
def test_common_input(self):
self.assert_default_name('forward', ['x'])
def test_list_input(self):
self.assert_default_name('func_with_list', ['l_0', 'l_1'])
def test_dict_input(self):
self.assert_default_name('func_with_dict', ['x', 'y'])
def test_nest_input(self):
self.assert_default_name('func_with_list_dict', ['dl_0', 'x', 'y'])
class TestDeclarativeAPI(unittest.TestCase): class TestDeclarativeAPI(unittest.TestCase):
def test_error(self): def test_error(self):
func = declarative(dyfunc_to_variable) func = declarative(dyfunc_to_variable)
......
...@@ -33,10 +33,14 @@ class TestCollectFPNProposalstOp(OpTest): ...@@ -33,10 +33,14 @@ class TestCollectFPNProposalstOp(OpTest):
for i in range(self.num_level)] for i in range(self.num_level)]
self.inputs = { self.inputs = {
'MultiLevelRois': inputs_x, 'MultiLevelRois': inputs_x,
"MultiLevelScores": self.scores_input "MultiLevelScores": self.scores_input,
'MultiLevelRoIsNum': []
} }
self.attrs = {'post_nms_topN': self.post_nms_top_n, } self.attrs = {'post_nms_topN': self.post_nms_top_n, }
self.outputs = {'FpnRois': (self.rois, [self.lod])} self.outputs = {
'FpnRois': (self.rois, [self.lod]),
'RoisNum': np.array(self.lod).astype('int32')
}
def init_test_case(self): def init_test_case(self):
self.post_nms_top_n = 20 self.post_nms_top_n = 20
...@@ -96,5 +100,32 @@ class TestCollectFPNProposalstOp(OpTest): ...@@ -96,5 +100,32 @@ class TestCollectFPNProposalstOp(OpTest):
self.check_output(check_dygraph=False) self.check_output(check_dygraph=False)
class TestCollectFPNProposalstOpWithRoisNum(TestCollectFPNProposalstOp):
    """Variant of the collect_fpn_proposals op test that feeds 'MultiLevelRoIsNum'."""

    def set_data(self):
        """Build ``inputs``, ``attrs`` and ``outputs`` for the op under test."""
        self.init_test_case()
        self.make_rois()

        # Scores are set on self before calc_rois_collect() runs, as in the
        # parent flow.
        level_scores = []
        for i in range(self.num_level):
            score_column = self.scores[i].reshape(-1, 1)
            level_scores.append(('y%d' % i, (score_column, self.rois_lod[i])))
        self.scores_input = level_scores

        self.rois, self.lod = self.calc_rois_collect()

        # The first column of each roi row is dropped before it is fed.
        level_rois = []
        for i in range(self.num_level):
            boxes = self.roi_inputs[i][:, 1:]
            level_rois.append(('x%d' % i, (boxes, self.rois_lod[i])))

        # Per-level roi counts, taken from the first lod entry of each level.
        level_rois_num = []
        for i in range(self.num_level):
            counts = np.array(self.rois_lod[i][0]).astype('int32')
            level_rois_num.append(('rois%d' % i, counts))

        self.inputs = {
            'MultiLevelRois': level_rois,
            "MultiLevelScores": self.scores_input,
            'MultiLevelRoIsNum': level_rois_num
        }
        self.attrs = {'post_nms_topN': self.post_nms_top_n, }
        self.outputs = {
            'FpnRois': (self.rois, [self.lod]),
            'RoisNum': np.array(self.lod).astype('int32')
        }
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import os
import paddle.distributed.fleet.base.role_maker as role_maker
import time
class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_PSERVER_NUMS"] = "2"
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
os.environ["PADDLE_TRAINER_ID"] = "0"
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
"127.0.0.1:36001,127.0.0.2:36001"
def test_a_sync_optimizer1(self):
os.environ["TRAINING_ROLE"] = "TRAINER"
import paddle.distributed.fleet as fleet
main_program = paddle.fluid.Program()
startup_program = paddle.fluid.Program()
paddle.fluid.framework.switch_main_program(main_program)
paddle.fluid.framework.switch_startup_program(startup_program)
fleet.init(role_maker.PaddleCloudRoleMaker())
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.auto = True
optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
self.assertTrue(optimizer.user_defined_strategy.a_sync)
a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
self.assertTrue(a_sync_configs['k_steps'] == 0)
def test_a_sync_optimizer2(self):
os.environ["TRAINING_ROLE"] = "TRAINER"
import paddle.distributed.fleet as fleet
main_program = paddle.fluid.Program()
startup_program = paddle.fluid.Program()
paddle.fluid.framework.switch_main_program(main_program)
paddle.fluid.framework.switch_startup_program(startup_program)
fleet.init(role_maker.PaddleCloudRoleMaker())
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.auto = True
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
self.assertTrue(optimizer.user_defined_strategy.a_sync)
a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
self.assertTrue(a_sync_configs['k_steps'] == 800)
def test_a_sync_optimizer3(self):
os.environ["TRAINING_ROLE"] = "TRAINER"
import paddle.distributed.fleet as fleet
main_program = paddle.fluid.Program()
startup_program = paddle.fluid.Program()
paddle.fluid.framework.switch_main_program(main_program)
paddle.fluid.framework.switch_startup_program(startup_program)
fleet.init(role_maker.PaddleCloudRoleMaker())
input_x = paddle.fluid.layers.data(
name="x",
shape=[-1, 1],
dtype="int64",
lod_level=1,
append_batch_size=False)
x_embedding = paddle.fluid.layers.embedding(
is_distributed=False,
input=input_x,
size=[1000000000, 100000],
param_attr=paddle.fluid.ParamAttr(
name="embedding",
initializer=paddle.fluid.initializer.Constant(value=0.01)),
is_sparse=True)
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=x_embedding, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.auto = True
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
self.assertTrue(optimizer.user_defined_strategy.a_sync)
a_sync_configs = optimizer.user_defined_strategy.a_sync_configs
self.assertTrue(a_sync_configs['k_steps'] == 0)
if __name__ == "__main__":
unittest.main()
...@@ -76,9 +76,10 @@ class FleetDistRunnerBase(object): ...@@ -76,9 +76,10 @@ class FleetDistRunnerBase(object):
return role return role
def build_strategy(self, args): def build_strategy(self, args):
if args.mode == "sync":
self.strategy = paddle.distributed.fleet.DistributedStrategy() self.strategy = paddle.distributed.fleet.DistributedStrategy()
self.strategy.a_sync = False self.strategy.a_sync = False
if args.mode == "async": elif args.mode == "async":
self.strategy = paddle.distributed.fleet.DistributedStrategy() self.strategy = paddle.distributed.fleet.DistributedStrategy()
self.strategy.a_sync = True self.strategy.a_sync = True
elif args.mode == "geo": elif args.mode == "geo":
...@@ -87,6 +88,10 @@ class FleetDistRunnerBase(object): ...@@ -87,6 +88,10 @@ class FleetDistRunnerBase(object):
self.strategy.a_sync_configs = { self.strategy.a_sync_configs = {
"k_steps": args.geo_sgd_need_push_nums "k_steps": args.geo_sgd_need_push_nums
} }
elif args.mode == "auto":
self.strategy = paddle.distributed.fleet.DistributedStrategy()
self.strategy.auto = True
self.dump_param = os.getenv("dump_param", "").split(",") self.dump_param = os.getenv("dump_param", "").split(",")
self.dump_fields = os.getenv("dump_fields", "").split(",") self.dump_fields = os.getenv("dump_fields", "").split(",")
self.dump_fields_path = os.getenv("dump_fields_path", "") self.dump_fields_path = os.getenv("dump_fields_path", "")
...@@ -232,14 +237,17 @@ class TestFleetBase(unittest.TestCase): ...@@ -232,14 +237,17 @@ class TestFleetBase(unittest.TestCase):
tr0_pipe = open(tempfile.gettempdir() + "/tr0_err.log", "wb+") tr0_pipe = open(tempfile.gettempdir() + "/tr0_err.log", "wb+")
tr1_pipe = open(tempfile.gettempdir() + "/tr1_err.log", "wb+") tr1_pipe = open(tempfile.gettempdir() + "/tr1_err.log", "wb+")
tr0_out = open(tempfile.gettempdir() + "/tr0_stdout.log", "wb+")
tr1_out = open(tempfile.gettempdir() + "/tr1_stdout.log", "wb+")
tr0_proc = subprocess.Popen( tr0_proc = subprocess.Popen(
tr0_cmd.strip().split(" "), tr0_cmd.strip().split(" "),
stdout=subprocess.PIPE, stdout=tr0_out,
stderr=tr0_pipe, stderr=tr0_pipe,
env=required_envs) env=required_envs)
tr1_proc = subprocess.Popen( tr1_proc = subprocess.Popen(
tr1_cmd.strip().split(" "), tr1_cmd.strip().split(" "),
stdout=subprocess.PIPE, stdout=tr1_out,
stderr=tr1_pipe, stderr=tr1_pipe,
env=required_envs) env=required_envs)
......
...@@ -52,6 +52,38 @@ class TestDistMnistSync2x2(TestFleetBase): ...@@ -52,6 +52,38 @@ class TestDistMnistSync2x2(TestFleetBase):
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True) "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistMnistAuto2x2(TestFleetBase):
    """Runs the ``dist_fleet_ctr.py`` model through the cluster in "auto" mode."""

    def _setup_config(self):
        self._mode = "auto"
        self._reader = "pyreader"

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=False,
                         need_envs=None):
        """Launch the cluster run for ``model_file``.

        Args:
            model_file: script handed to ``self._run_cluster``.
            delta: accepted for signature compatibility; not used here.
            check_error_log: when true, verbose glog output is enabled.
            need_envs: optional mapping of extra environment variables that
                override the defaults below. ``None`` means no extras.
        """
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
            "http_proxy": "",
            "CPU_NUM": "2"
        }

        # A None default avoids sharing one dict object between calls.
        if need_envs is not None:
            required_envs.update(need_envs)

        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"

        # The two losses are unpacked but not compared in this test.
        _tr0_losses, _tr1_losses = self._run_cluster(model_file, required_envs)

    def test_dist_train(self):
        self.check_with_place(
            "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistMnistAsync2x2(TestFleetBase): class TestDistMnistAsync2x2(TestFleetBase):
def _setup_config(self): def _setup_config(self):
self._mode = "async" self._mode = "async"
......
...@@ -35,9 +35,10 @@ class TestDistributeFPNProposalsOp(OpTest): ...@@ -35,9 +35,10 @@ class TestDistributeFPNProposalsOp(OpTest):
} }
output = [('out%d' % i, self.rois_fpn[i]) output = [('out%d' % i, self.rois_fpn[i])
for i in range(len(self.rois_fpn))] for i in range(len(self.rois_fpn))]
self.outputs = { self.outputs = {
'MultiFpnRois': output, 'MultiFpnRois': output,
'RestoreIndex': self.rois_idx_restore.reshape(-1, 1) 'RestoreIndex': self.rois_idx_restore.reshape(-1, 1),
} }
def init_test_case(self): def init_test_case(self):
...@@ -117,5 +118,34 @@ class TestDistributeFPNProposalsOp(OpTest): ...@@ -117,5 +118,34 @@ class TestDistributeFPNProposalsOp(OpTest):
self.check_output() self.check_output()
class TestDistributeFPNProposalsOpWithRoisNum(TestDistributeFPNProposalsOp):
    """Variant of the distribute_fpn_proposals op test that feeds 'RoisNum'
    and also expects the 'MultiLevelRoIsNum' output."""

    def set_data(self):
        """Build ``inputs``, ``attrs`` and ``outputs`` for the op under test."""
        self.init_test_case()
        self.make_rois()
        self.rois_fpn, self.rois_idx_restore = self.calc_rois_distribute()

        fpn_rois = (self.rois[:, 1:5], self.rois_lod)
        rois_num = np.array(self.rois_lod[0]).astype('int32')
        self.inputs = {'FpnRois': fpn_rois, 'RoisNum': rois_num}

        self.attrs = {
            'max_level': self.roi_max_level,
            'min_level': self.roi_min_level,
            'refer_scale': self.canonical_scale,
            'refer_level': self.canonical_level
        }

        # One named entry per level for both the rois and their counts.
        multi_fpn_rois = []
        multi_level_rois_num = []
        for level, level_rois in enumerate(self.rois_fpn):
            multi_fpn_rois.append(('out%d' % level, level_rois))
            level_count = np.array(level_rois[1][0]).astype('int32')
            multi_level_rois_num.append(('rois_num%d' % level, level_count))

        self.outputs = {
            'MultiFpnRois': multi_fpn_rois,
            'RestoreIndex': self.rois_idx_restore.reshape(-1, 1),
            'MultiLevelRoIsNum': multi_level_rois_num
        }
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -34,18 +34,18 @@ def generate_proposals_in_python(scores, bbox_deltas, im_info, anchors, ...@@ -34,18 +34,18 @@ def generate_proposals_in_python(scores, bbox_deltas, im_info, anchors,
rpn_rois = [] rpn_rois = []
rpn_roi_probs = [] rpn_roi_probs = []
lod = [] rois_num = []
num_images = scores.shape[0] num_images = scores.shape[0]
for img_idx in range(num_images): for img_idx in range(num_images):
img_i_boxes, img_i_probs = proposal_for_one_image( img_i_boxes, img_i_probs = proposal_for_one_image(
im_info[img_idx, :], all_anchors, variances, im_info[img_idx, :], all_anchors, variances,
bbox_deltas[img_idx, :, :, :], scores[img_idx, :, :, :], bbox_deltas[img_idx, :, :, :], scores[img_idx, :, :, :],
pre_nms_topN, post_nms_topN, nms_thresh, min_size, eta) pre_nms_topN, post_nms_topN, nms_thresh, min_size, eta)
lod.append(img_i_probs.shape[0]) rois_num.append(img_i_probs.shape[0])
rpn_rois.append(img_i_boxes) rpn_rois.append(img_i_boxes)
rpn_roi_probs.append(img_i_probs) rpn_roi_probs.append(img_i_probs)
return rpn_rois, rpn_roi_probs, lod return rpn_rois, rpn_roi_probs, rois_num
def proposal_for_one_image(im_info, all_anchors, variances, bbox_deltas, scores, def proposal_for_one_image(im_info, all_anchors, variances, bbox_deltas, scores,
...@@ -87,6 +87,10 @@ def proposal_for_one_image(im_info, all_anchors, variances, bbox_deltas, scores, ...@@ -87,6 +87,10 @@ def proposal_for_one_image(im_info, all_anchors, variances, bbox_deltas, scores,
proposals = clip_tiled_boxes(proposals, im_info[:2]) proposals = clip_tiled_boxes(proposals, im_info[:2])
# remove predicted boxes with height or width < min_size # remove predicted boxes with height or width < min_size
keep = filter_boxes(proposals, min_size, im_info) keep = filter_boxes(proposals, min_size, im_info)
if len(keep) == 0:
proposals = np.zeros((1, 4)).astype('float32')
scores = np.zeros((1, 1)).astype('float32')
return proposals, scores
proposals = proposals[keep, :] proposals = proposals[keep, :]
scores = scores[keep, :] scores = scores[keep, :]
...@@ -280,8 +284,8 @@ class TestGenerateProposalsOp(OpTest): ...@@ -280,8 +284,8 @@ class TestGenerateProposalsOp(OpTest):
} }
self.outputs = { self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.lod]), 'RpnRois': (self.rpn_rois[0], [self.rois_num]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.lod]), 'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
} }
def test_check_output(self): def test_check_output(self):
...@@ -320,7 +324,7 @@ class TestGenerateProposalsOp(OpTest): ...@@ -320,7 +324,7 @@ class TestGenerateProposalsOp(OpTest):
(batch_size, num_anchors * 4, layer_h, layer_w)).astype('float32') (batch_size, num_anchors * 4, layer_h, layer_w)).astype('float32')
def init_test_output(self): def init_test_output(self):
self.rpn_rois, self.rpn_roi_probs, self.lod = generate_proposals_in_python( self.rpn_rois, self.rpn_roi_probs, self.rois_num = generate_proposals_in_python(
self.scores, self.bbox_deltas, self.im_info, self.anchors, self.scores, self.bbox_deltas, self.im_info, self.anchors,
self.variances, self.pre_nms_topN, self.post_nms_topN, self.variances, self.pre_nms_topN, self.post_nms_topN,
self.nms_thresh, self.min_size, self.eta) self.nms_thresh, self.min_size, self.eta)
...@@ -349,12 +353,21 @@ class TestGenerateProposalsOutLodOp(TestGenerateProposalsOp): ...@@ -349,12 +353,21 @@ class TestGenerateProposalsOutLodOp(TestGenerateProposalsOp):
} }
self.outputs = { self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.lod]), 'RpnRois': (self.rpn_rois[0], [self.rois_num]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.lod]), 'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
'RpnRoisLod': (np.asarray( 'RpnRoisNum': (np.asarray(
self.lod, dtype=np.int32)) self.rois_num, dtype=np.int32))
} }
class TestGenerateProposalsOpNoBoxLeft(TestGenerateProposalsOp):
    """Re-runs the generate_proposals op test with ``min_size`` raised to 1000.

    NOTE(review): judging by the class name, the large ``min_size`` is meant
    to filter out every box -- confirm against the reference implementation.
    """

    def init_test_params(self):
        # Top-N limits: pre-NMS (train 12000, test 2000) and
        # post-NMS (train 6000, test 1000).
        self.pre_nms_topN, self.post_nms_topN = 12000, 5000
        self.nms_thresh = 0.7
        self.min_size = 1000.0
        self.eta = 1.
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -100,7 +100,7 @@ def add_cases(suite): ...@@ -100,7 +100,7 @@ def add_cases(suite):
GridSampleTestCase( GridSampleTestCase(
methodName='runTest', methodName='runTest',
mode='bilinear', mode='bilinear',
padding_mode='reflect', padding_mode='reflection',
align_corners=True)) align_corners=True))
suite.addTest( suite.addTest(
GridSampleTestCase( GridSampleTestCase(
......
...@@ -73,7 +73,7 @@ def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode): ...@@ -73,7 +73,7 @@ def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode):
if padding_mode == "border": if padding_mode == "border":
grid_slice = clip(grid_slice, 0, max_val) grid_slice = clip(grid_slice, 0, max_val)
elif padding_mode == "reflect": elif padding_mode == "reflection":
double_range = 2 * max_val if align_corners else (max_val + 1) * 2 double_range = 2 * max_val if align_corners else (max_val + 1) * 2
grid_abs = np.abs(grid_slice) if align_corners else np.abs(grid_slice + grid_abs = np.abs(grid_slice) if align_corners else np.abs(grid_slice +
0.5) 0.5)
...@@ -211,7 +211,7 @@ class Case2(TestGridSamplerOp): ...@@ -211,7 +211,7 @@ class Case2(TestGridSamplerOp):
self.grid_shape = (2, 8, 9, 2) self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3) self.theta_shape = (2, 2, 3)
self.align_corners = False self.align_corners = False
self.padding_mode = "reflect" self.padding_mode = "reflection"
self.mode = "bilinear" self.mode = "bilinear"
...@@ -221,7 +221,7 @@ class Case3(TestGridSamplerOp): ...@@ -221,7 +221,7 @@ class Case3(TestGridSamplerOp):
self.grid_shape = (2, 8, 9, 2) self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3) self.theta_shape = (2, 2, 3)
self.align_corners = True self.align_corners = True
self.padding_mode = "reflect" self.padding_mode = "reflection"
self.mode = "bilinear" self.mode = "bilinear"
...@@ -231,7 +231,7 @@ class Case4(TestGridSamplerOp): ...@@ -231,7 +231,7 @@ class Case4(TestGridSamplerOp):
self.grid_shape = (2, 8, 9, 2) self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3) self.theta_shape = (2, 2, 3)
self.align_corners = False self.align_corners = False
self.padding_mode = "reflect" self.padding_mode = "reflection"
self.mode = "nearest" self.mode = "nearest"
self.numeric_grad_delta = 0.0001 self.numeric_grad_delta = 0.0001
......
...@@ -24,7 +24,10 @@ def kldiv_loss(x, target, reduction): ...@@ -24,7 +24,10 @@ def kldiv_loss(x, target, reduction):
loss = np.where(target >= 0, output, np.zeros_like(x)) loss = np.where(target >= 0, output, np.zeros_like(x))
if reduction == "batchmean": if reduction == "batchmean":
if len(x.shape) > 0:
return loss.sum() / x.shape[0] return loss.sum() / x.shape[0]
else:
return loss.sum()
if reduction == "mean": if reduction == "mean":
return loss.mean() return loss.mean()
if reduction == "sum": if reduction == "sum":
...@@ -93,6 +96,9 @@ class TestKLDivLossDygraph(unittest.TestCase): ...@@ -93,6 +96,9 @@ class TestKLDivLossDygraph(unittest.TestCase):
def test_kl_loss_batchmean(self): def test_kl_loss_batchmean(self):
self.run_kl_loss('batchmean') self.run_kl_loss('batchmean')
def test_kl_loss_batchmean_shape(self):
self.run_kl_loss('batchmean', ())
def test_kl_loss_mean(self): def test_kl_loss_mean(self):
self.run_kl_loss('mean') self.run_kl_loss('mean')
......
...@@ -3318,15 +3318,29 @@ class TestBook(LayerTest): ...@@ -3318,15 +3318,29 @@ class TestBook(LayerTest):
return (out) return (out)
def test_roi_pool(self): def test_roi_pool(self):
# TODO(minqiyang): dygraph do not support lod now x_np = np.random.rand(2, 3, 8, 8).astype('float32')
rois_np = np.random.rand(3, 4).astype('float32')
rois_num_np = np.array([1, 2]).astype('int32')
with self.static_graph(): with self.static_graph():
x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") x = layers.data(name="x", shape=[3, 8, 8], dtype="float32")
rois = layers.data( rois = layers.data(name="rois", shape=[4], dtype="float32")
name="rois", shape=[4], dtype="float32", lod_level=1) rois_num = fluid.data(name="rois_num", shape=[None], dtype="int32")
rois_lod = layers.data( output = layers.roi_pool(x, rois, 4, 4, 0.5, rois_num=rois_num)
name="rois_lod", shape=[None, ], dtype="int", lod_level=1) static_res = self.get_static_graph_result(
output = layers.roi_pool(x, rois, 7, 7, 0.6, rois_lod) feed={'x': x_np,
return (output) 'rois': rois_np,
'rois_num': rois_num_np},
fetch_list=[output])[0]
with self.dynamic_graph():
x_dy = base.to_variable(x_np)
rois_dy = base.to_variable(rois_np)
rois_num_dy = base.to_variable(rois_num_np)
dy_res = layers.roi_pool(
x_dy, rois_dy, 4, 4, 0.5, rois_num=rois_num_dy)
dy_res_value = dy_res[0].numpy()
self.assertTrue(np.array_equal(static_res, dy_res_value))
def test_sequence_enumerate(self): def test_sequence_enumerate(self):
# TODO(minqiyang): dygraph do not support lod now # TODO(minqiyang): dygraph do not support lod now
...@@ -3335,16 +3349,29 @@ class TestBook(LayerTest): ...@@ -3335,16 +3349,29 @@ class TestBook(LayerTest):
out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0) out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0)
def test_roi_align(self): def test_roi_align(self):
# TODO(minqiyang): dygraph do not support lod now x_np = np.random.rand(2, 3, 8, 8).astype('float32')
rois_np = np.random.rand(3, 4).astype('float32')
rois_num_np = np.array([1, 2]).astype('int32')
with self.static_graph(): with self.static_graph():
x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") x = layers.data(name="x", shape=[3, 8, 8], dtype="float32")
rois = layers.data( rois = layers.data(name="rois", shape=[4], dtype="float32")
name="rois", shape=[4], dtype="float32", lod_level=1) rois_num = fluid.data(name="rois_num", shape=[None], dtype="int32")
rois_lod = layers.data( output = layers.roi_align(x, rois, 4, 4, 0.5, 2, rois_num=rois_num)
name="rois_lod", shape=[None, ], dtype="int", lod_level=1) static_res = self.get_static_graph_result(
output = layers.roi_align(x, rois, 14, 14, 0.5, 2, 'roi_align', feed={'x': x_np,
rois_lod) 'rois': rois_np,
return (output) 'rois_num': rois_num_np},
fetch_list=[output])[0]
with self.dynamic_graph():
x_dy = base.to_variable(x_np)
rois_dy = base.to_variable(rois_np)
rois_num_dy = base.to_variable(rois_num_np)
dy_res = layers.roi_align(
x_dy, rois_dy, 4, 4, 0.5, 2, rois_num=rois_num_dy)
dy_res_value = dy_res.numpy()
self.assertTrue(np.array_equal(static_res, dy_res_value))
def test_roi_perspective_transform(self): def test_roi_perspective_transform(self):
# TODO(minqiyang): dygraph do not support lod now # TODO(minqiyang): dygraph do not support lod now
......
...@@ -16,20 +16,49 @@ from __future__ import print_function ...@@ -16,20 +16,49 @@ from __future__ import print_function
import unittest import unittest
import paddle
import paddle.nn as nn
import numpy as np
paddle.disable_static()
class EmbeddingDygraph(unittest.TestCase): class EmbeddingDygraph(unittest.TestCase):
def test_1(self): def test_1(self):
import paddle x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
import paddle.nn as nn y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
import numpy as np paddle.disable_static(paddle.CPUPlace())
paddle.disable_static() x = paddle.to_tensor(x_data, stop_gradient=False)
y = paddle.to_tensor(y_data, stop_gradient=False)
embedding = paddle.nn.Embedding(10, 3, sparse=True)
w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32)
embedding.weight.set_value(w0)
adam = paddle.optimizer.Adam(
parameters=[embedding.weight], learning_rate=0.01)
adam.clear_grad()
out = embedding(x)
out.backward()
adam.step()
def test_2(self):
x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
paddle.disable_static(paddle.CPUPlace())
x = paddle.to_tensor(x_data, stop_gradient=False)
y = paddle.to_tensor(y_data, stop_gradient=False)
with self.assertRaises(ValueError):
embedding = paddle.nn.Embedding(10, 3, padding_idx=11, sparse=True)
# example 1 with self.assertRaises(ValueError):
inp_word = np.array([[2, 3, 5], [4, 2, 1]]).astype('int64') embedding = paddle.nn.Embedding(-1, 3, sparse=True)
inp_word.shape # [2, 3]
dict_size = 20
emb = nn.Embedding(dict_size, 32, weight_attr='emb.w', sparse=False) with self.assertRaises(ValueError):
embedding = paddle.nn.Embedding(10, -3, sparse=True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -73,8 +73,13 @@ class EmbeddingStatic(unittest.TestCase): ...@@ -73,8 +73,13 @@ class EmbeddingStatic(unittest.TestCase):
dtype="int32") dtype="int32")
emb = functional.embedding( emb = functional.embedding(
x=label, weight=weight, sparse=True, name="embedding") x=label,
weight=weight,
padding_idx=129,
sparse=True,
name="embedding")
with self.assertRaises(ValueError):
test_bad_x() test_bad_x()
......
...@@ -101,6 +101,29 @@ class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase): ...@@ -101,6 +101,29 @@ class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase):
self.func(p) self.func(p)
class TestReduceSumWithDimDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [7, 11]
eps = 0.05
dtype = np.float64
x = layers.data('x', shape, False, dtype)
x.persistable = True
y = layers.reduce_sum(x, dim=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMulDoubleGradCheck(unittest.TestCase): class TestMulDoubleGradCheck(unittest.TestCase):
@prog_scope() @prog_scope()
def func(self, place): def func(self, place):
......
...@@ -26,11 +26,11 @@ def p_norm(x, axis, porder, keepdims=False): ...@@ -26,11 +26,11 @@ def p_norm(x, axis, porder, keepdims=False):
if axis is None: if axis is None:
x = x.flatten() x = x.flatten()
if porder == np.inf: if porder == np.inf:
r = np.amax(np.abs(x)) r = np.amax(np.abs(x), keepdims=keepdims)
elif porder == -np.inf: elif porder == -np.inf:
r = np.amin(np.abs(x)) r = np.amin(np.abs(x), keepdims=keepdims)
else: else:
r = np.linalg.norm(x, ord=porder) r = np.linalg.norm(x, ord=porder, keepdims=keepdims)
elif isinstance(axis, list or tuple) and len(axis) == 2: elif isinstance(axis, list or tuple) and len(axis) == 2:
if porder == np.inf: if porder == np.inf:
axis = tuple(axis) axis = tuple(axis)
...@@ -41,10 +41,10 @@ def p_norm(x, axis, porder, keepdims=False): ...@@ -41,10 +41,10 @@ def p_norm(x, axis, porder, keepdims=False):
elif porder == 0: elif porder == 0:
axis = tuple(axis) axis = tuple(axis)
r = x.astype(bool) r = x.astype(bool)
r = np.sum(r, axis) r = np.sum(r, axis, keepdims=keepdims)
elif porder == 1: elif porder == 1:
axis = tuple(axis) axis = tuple(axis)
r = np.sum(np.abs(x), axis) r = np.sum(np.abs(x), axis, keepdims=keepdims)
else: else:
axis = tuple(axis) axis = tuple(axis)
xp = np.power(np.abs(x), porder) xp = np.power(np.abs(x), porder)
...@@ -61,7 +61,7 @@ def p_norm(x, axis, porder, keepdims=False): ...@@ -61,7 +61,7 @@ def p_norm(x, axis, porder, keepdims=False):
def frobenius_norm(x, axis=None, keepdims=False): def frobenius_norm(x, axis=None, keepdims=False):
if isinstance(axis, list): axis = tuple(axis) if isinstance(axis, list): axis = tuple(axis)
if axis is None: axis = (-2, -1) if axis is None: x = x.reshape(1, x.size)
r = np.linalg.norm( r = np.linalg.norm(
x, ord='fro', axis=axis, keepdims=keepdims).astype(x.dtype) x, ord='fro', axis=axis, keepdims=keepdims).astype(x.dtype)
return r return r
...@@ -217,28 +217,37 @@ class TestPnormOp5(TestPnormOp): ...@@ -217,28 +217,37 @@ class TestPnormOp5(TestPnormOp):
self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)
def run_fro(self, p, axis, shape_x, dtype): def run_fro(self, p, axis, shape_x, dtype, keep_dim, check_dim=False):
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
data = fluid.data(name="X", shape=shape_x, dtype=dtype) data = fluid.data(name="X", shape=shape_x, dtype=dtype)
out = paddle.norm(x=data, p=p, axis=axis) out = paddle.norm(x=data, p=p, axis=axis, keepdim=keep_dim)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype) np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
expected_result = frobenius_norm(np_input, axis=axis) expected_result = frobenius_norm(np_input, axis=axis, keepdims=keep_dim)
result, = exe.run(feed={"X": np_input}, fetch_list=[out]) result, = exe.run(feed={"X": np_input}, fetch_list=[out])
self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True) self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True)
if keep_dim and check_dim:
self.assertEqual(
(np.abs(np.array(result.shape) - np.array(expected_result.shape)) <
1e-6).all(), True)
def run_pnorm(self, p, axis, shape_x, dtype): def run_pnorm(self, p, axis, shape_x, dtype, keep_dim, check_dim=False):
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
data = fluid.data(name="X", shape=shape_x, dtype=dtype) data = fluid.data(name="X", shape=shape_x, dtype=dtype)
out = paddle.norm(x=data, p=p, axis=axis) out = paddle.norm(x=data, p=p, axis=axis, keepdim=keep_dim)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype) np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
expected_result = p_norm(np_input, porder=p, axis=axis).astype(dtype) expected_result = p_norm(
np_input, porder=p, axis=axis, keepdims=keep_dim).astype(dtype)
result, = exe.run(feed={"X": np_input}, fetch_list=[out]) result, = exe.run(feed={"X": np_input}, fetch_list=[out])
self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True) self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True)
if keep_dim and check_dim:
self.assertEqual(
(np.abs(np.array(result.shape) - np.array(expected_result.shape)) <
1e-6).all(), True)
def run_graph(self, p, axis, shape_x, dtype): def run_graph(self, p, axis, shape_x, dtype):
...@@ -253,6 +262,7 @@ def run_graph(self, p, axis, shape_x, dtype): ...@@ -253,6 +262,7 @@ def run_graph(self, p, axis, shape_x, dtype):
# compute frobenius norm along last two dimensions. # compute frobenius norm along last two dimensions.
out_fro = paddle.norm(x, p='fro') out_fro = paddle.norm(x, p='fro')
out_fro = paddle.norm(x, p='fro', axis=0)
out_fro = paddle.norm(x, p='fro', axis=[0, 1]) out_fro = paddle.norm(x, p='fro', axis=[0, 1])
# compute 2-order norm along [0,1] dimension. # compute 2-order norm along [0,1] dimension.
out_pnorm = paddle.norm(x, p=2, axis=[0, 1]) out_pnorm = paddle.norm(x, p=2, axis=[0, 1])
...@@ -274,27 +284,133 @@ def run_graph(self, p, axis, shape_x, dtype): ...@@ -274,27 +284,133 @@ def run_graph(self, p, axis, shape_x, dtype):
class API_NormTest(unittest.TestCase): class API_NormTest(unittest.TestCase):
def test_basic(self): def test_basic(self):
run_fro(self, p='fro', axis=None, shape_x=[2, 3, 4], dtype="float32") keep_dims = {False, True}
run_fro(self, p='fro', axis=[0, 1], shape_x=[2, 3, 4], dtype="float64") for keep in keep_dims:
run_pnorm(self, p=2, axis=None, shape_x=[3, 4], dtype="float32") run_fro(
run_pnorm(self, p=2, axis=1, shape_x=[3, 4], dtype="float64") self,
run_pnorm(self, p=np.inf, axis=0, shape_x=[2, 3, 4], dtype="float32") p='fro',
run_pnorm(self, p=np.inf, axis=None, shape_x=[2, 3, 4], dtype="float32") axis=None,
run_pnorm(self, p=-np.inf, axis=0, shape_x=[2, 3, 4], dtype="float64") shape_x=[2, 3, 4],
dtype="float32",
keep_dim=keep)
run_fro(
self,
p='fro',
axis=[0, 1],
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm( run_pnorm(
self, p=-np.inf, axis=None, shape_x=[2, 3, 4], dtype="float64") self,
run_pnorm(self, p=0, axis=1, shape_x=[3, 4], dtype="float64") p=2,
axis=None,
run_pnorm(self, p=1, axis=1, shape_x=[3, 4], dtype="float64") shape_x=[3, 4],
run_pnorm(self, p=0, axis=None, shape_x=[3, 4], dtype="float64") dtype="float32",
run_pnorm(self, p=2, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64") keep_dim=keep)
run_pnorm(self, p=2, axis=-1, shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=1, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=0, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm( run_pnorm(
self, p=np.inf, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64") self,
p=2,
axis=1,
shape_x=[3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm( run_pnorm(
self, p=-np.inf, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64") self,
p=np.inf,
axis=0,
shape_x=[2, 3, 4],
dtype="float32",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=np.inf,
axis=None,
shape_x=[2, 3, 4],
dtype="float32",
keep_dim=keep)
run_pnorm(
self,
p=-np.inf,
axis=0,
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=-np.inf,
axis=None,
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep)
run_pnorm(
self,
p=0,
axis=1,
shape_x=[3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=1,
axis=1,
shape_x=[3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=0,
axis=None,
shape_x=[3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=2,
axis=[0, 1],
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=2,
axis=-1,
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=1,
axis=[0, 1],
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=np.inf,
axis=[0, 1],
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
run_pnorm(
self,
p=-np.inf,
axis=[0, 1],
shape_x=[2, 3, 4],
dtype="float64",
keep_dim=keep,
check_dim=True)
def test_dygraph(self): def test_dygraph(self):
run_graph(self, p='fro', axis=None, shape_x=[2, 3, 4], dtype="float32") run_graph(self, p='fro', axis=None, shape_x=[2, 3, 4], dtype="float32")
...@@ -315,6 +431,7 @@ class API_NormTest(unittest.TestCase): ...@@ -315,6 +431,7 @@ class API_NormTest(unittest.TestCase):
paddle.norm(data, p=p, out=out) paddle.norm(data, p=p, out=out)
self.assertRaises(TypeError, err_dtype, "fro", [2, 2], "int64") self.assertRaises(TypeError, err_dtype, "fro", [2, 2], "int64")
self.assertRaises(ValueError, paddle.norm, "inf", [2], "int64")
out = fluid.data(name="out", shape=[1], dtype="int64") out = fluid.data(name="out", shape=[1], dtype="int64")
self.assertRaises(TypeError, err_dtype, "fro", [2, 2], "float64", self.assertRaises(TypeError, err_dtype, "fro", [2, 2], "float64",
out) out)
...@@ -325,6 +442,7 @@ class API_NormTest(unittest.TestCase): ...@@ -325,6 +442,7 @@ class API_NormTest(unittest.TestCase):
self.assertRaises(ValueError, paddle.norm, data, p="unsupport norm") self.assertRaises(ValueError, paddle.norm, data, p="unsupport norm")
self.assertRaises(ValueError, paddle.norm, data, p=[1]) self.assertRaises(ValueError, paddle.norm, data, p=[1])
self.assertRaises(ValueError, paddle.norm, data, p=[1], axis=-1) self.assertRaises(ValueError, paddle.norm, data, p=[1], axis=-1)
self.assertRaises(ValueError, paddle.norm, 0, [1, 0], "float64")
data = fluid.data(name="data_3d", shape=[2, 2, 2], dtype="float64") data = fluid.data(name="data_3d", shape=[2, 2, 2], dtype="float64")
self.assertRaises( self.assertRaises(
ValueError, paddle.norm, data, p='unspport', axis=[-3, -2, -1]) ValueError, paddle.norm, data, p='unspport', axis=[-3, -2, -1])
......
...@@ -165,7 +165,6 @@ class TestPool3d_API(unittest.TestCase): ...@@ -165,7 +165,6 @@ class TestPool3d_API(unittest.TestCase):
self.assertTrue(np.allclose(result.numpy(), result_np)) self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_dygraph_ndhwc_results(self, place): def check_max_dygraph_ndhwc_results(self, place):
print("run ndchw max pool3d")
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable( input = fluid.dygraph.to_variable(
...@@ -190,7 +189,6 @@ class TestPool3d_API(unittest.TestCase): ...@@ -190,7 +189,6 @@ class TestPool3d_API(unittest.TestCase):
np.transpose(result.numpy(), [0, 4, 1, 2, 3]), result_np)) np.transpose(result.numpy(), [0, 4, 1, 2, 3]), result_np))
def check_max_dygraph_ceilmode_results(self, place): def check_max_dygraph_ceilmode_results(self, place):
print("run ceil mode max pool3d")
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np) input = fluid.dygraph.to_variable(input_np)
......
...@@ -181,16 +181,11 @@ class TestROIAlignInLodOp(TestROIAlignOp): ...@@ -181,16 +181,11 @@ class TestROIAlignInLodOp(TestROIAlignOp):
self.calc_roi_align() self.calc_roi_align()
seq_len = self.rois_lod[0] seq_len = self.rois_lod[0]
cur_len = 0
lod = [cur_len]
for l in seq_len:
cur_len += l
lod.append(cur_len)
self.inputs = { self.inputs = {
'X': self.x, 'X': self.x,
'ROIs': (self.rois[:, 1:5], self.rois_lod), 'ROIs': (self.rois[:, 1:5], self.rois_lod),
'RoisLod': np.asarray(lod).astype('int64') 'RoisNum': np.asarray(seq_len).astype('int32')
} }
self.attrs = { self.attrs = {
......
...@@ -174,16 +174,11 @@ class TestROIPoolInLodOp(TestROIPoolOp): ...@@ -174,16 +174,11 @@ class TestROIPoolInLodOp(TestROIPoolOp):
self.calc_roi_pool() self.calc_roi_pool()
seq_len = self.rois_lod[0] seq_len = self.rois_lod[0]
cur_len = 0
lod = [cur_len]
for l in seq_len:
cur_len += l
lod.append(cur_len)
self.inputs = { self.inputs = {
'X': self.x, 'X': self.x,
'ROIs': (self.rois[:, 1:5], self.rois_lod), 'ROIs': (self.rois[:, 1:5], self.rois_lod),
'RoisLod': np.asarray(lod).astype('int64') 'RoisNum': np.asarray(seq_len).astype('int32')
} }
self.attrs = { self.attrs = {
......
...@@ -142,6 +142,18 @@ class TestTrilTriuOpAPI(unittest.TestCase): ...@@ -142,6 +142,18 @@ class TestTrilTriuOpAPI(unittest.TestCase):
self.assertTrue(np.allclose(tril_out, np.tril(data))) self.assertTrue(np.allclose(tril_out, np.tril(data)))
self.assertTrue(np.allclose(triu_out, np.triu(data))) self.assertTrue(np.allclose(triu_out, np.triu(data)))
def test_fluid_api(self):
data = np.random.random([1, 9, 9, 4]).astype('float32')
x = fluid.data(shape=[1, 9, -1, 4], dtype='float32', name='x')
triu_out = fluid.layers.triu(x)
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
triu_out = exe.run(fluid.default_main_program(),
feed={"x": data},
fetch_list=[triu_out])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
import numpy as np import numpy as np
import numbers
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
...@@ -107,6 +109,11 @@ def summary(net, input_size, batch_size=None, dtypes=None): ...@@ -107,6 +109,11 @@ def summary(net, input_size, batch_size=None, dtypes=None):
if batch_size is None: if batch_size is None:
batch_size = -1 batch_size = -1
if not paddle.in_dynamic_mode():
warnings.warn(
"Your model was created in static mode, this may not get correct summary information!"
)
result, params_info = summary_string(net, _input_size, batch_size, dtypes) result, params_info = summary_string(net, _input_size, batch_size, dtypes)
print(result) print(result)
...@@ -121,16 +128,16 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None): ...@@ -121,16 +128,16 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
depth = len(list(model.sublayers())) depth = len(list(model.sublayers()))
def register_hook(module): def register_hook(layer):
def hook(module, input, output): def hook(layer, input, output):
class_name = str(module.__class__).split(".")[-1].split("'")[0] class_name = str(layer.__class__).split(".")[-1].split("'")[0]
try: try:
module_idx = int(module._full_name.split('_')[-1]) layer_idx = int(layer._full_name.split('_')[-1])
except: except:
module_idx = len(summary) layer_idx = len(summary)
m_key = "%s-%i" % (class_name, module_idx + 1) m_key = "%s-%i" % (class_name, layer_idx + 1)
summary[m_key] = OrderedDict() summary[m_key] = OrderedDict()
summary[m_key]["input_shape"] = list(input[0].shape) summary[m_key]["input_shape"] = list(input[0].shape)
summary[m_key]["input_shape"][0] = batch_size summary[m_key]["input_shape"][0] = batch_size
...@@ -142,23 +149,50 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None): ...@@ -142,23 +149,50 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
summary[m_key]["output_shape"][0] = batch_size summary[m_key]["output_shape"][0] = batch_size
params = 0 params = 0
if hasattr(module, "weight") and hasattr(module.weight, "shape"):
params += np.prod(module.weight.shape) if paddle.in_dynamic_mode():
summary[m_key]["trainable"] = module.weight.trainable or ( layer_state_dict = layer._parameters
not module.weight.stop_gradient) else:
if hasattr(module, "bias") and hasattr(module.bias, "shape"): layer_state_dict = layer.state_dict()
params += np.prod(module.bias.shape)
for k, v in layer_state_dict.items():
params += np.prod(v.shape)
try:
if (getattr(getattr(layer, k), 'trainable')) and (
not getattr(getattr(layer, k), 'stop_gradient')):
summary[m_key]["trainable"] = True
else:
summary[m_key]["trainable"] = False
except:
summary[m_key]["trainable"] = True
summary[m_key]["nb_params"] = params summary[m_key]["nb_params"] = params
if (not isinstance(module, nn.Sequential) and if (not isinstance(layer, nn.Sequential) and
not isinstance(module, nn.LayerList) and not isinstance(layer, nn.LayerList) and
(not (module == model) or depth < 1)): (not (layer == model) or depth < 1)):
hooks.append(layer.register_forward_post_hook(hook))
def _check_input_size(input_sizes):
for input_size in input_sizes:
for item in input_size:
if not isinstance(item, numbers.Number):
raise TypeError(
"Expected item in input size be a number, but got {}".
format(type(item)))
hooks.append(module.register_forward_post_hook(hook)) if item <= 0:
raise ValueError(
"Expected item in input size greater than zero, but got {}".
format(item))
if isinstance(input_size, tuple): if isinstance(input_size, tuple):
input_size = [input_size] input_size = [input_size]
_check_input_size(input_size)
x = [ x = [
paddle.rand( paddle.rand(
[2] + list(in_size), dtype=dtype) [2] + list(in_size), dtype=dtype)
...@@ -197,7 +231,12 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None): ...@@ -197,7 +231,12 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
"{0:,}".format(summary[layer]["nb_params"]), ) "{0:,}".format(summary[layer]["nb_params"]), )
total_params += summary[layer]["nb_params"] total_params += summary[layer]["nb_params"]
try:
total_output += np.prod(summary[layer]["output_shape"]) total_output += np.prod(summary[layer]["output_shape"])
except:
for output_shape in summary[layer]["output_shape"]:
total_output += np.prod(output_shape)
if "trainable" in summary[layer]: if "trainable" in summary[layer]:
if summary[layer]["trainable"] == True: if summary[layer]["trainable"] == True:
trainable_params += summary[layer]["nb_params"] trainable_params += summary[layer]["nb_params"]
......
...@@ -113,17 +113,18 @@ def one_hot(x, num_classes, name=None): ...@@ -113,17 +113,18 @@ def one_hot(x, num_classes, name=None):
def embedding(x, weight, padding_idx=None, sparse=False, name=None): def embedding(x, weight, padding_idx=None, sparse=False, name=None):
""" """
The operator is used to lookup embeddings vector of ids provided by :attr:`input` . The operator is used to lookup embeddings vector of ids provided by :attr:`x` .
The shape of output Tensor is generated by appending the last dimension of the input Tensor shape The shape of output Tensor is generated by appending the last dimension of the input Tensor shape
with embedding size. with embedding size.
**Note:** The id in :attr:`input` must satisfy :math:`0 =< id < weight.shape[0]` ,
**Note:** The id in :attr:`x` must satisfy :math:`0 =< id < weight.shape[0]` ,
otherwise the program will throw an exception and exit. otherwise the program will throw an exception and exit.
.. code-block:: text .. code-block:: text
Case 1: Case 1:
input is a Tensor. x is a Tensor.
padding_idx = -1 padding_idx = -1
x.data = [[1, 3], [2, 4], [4, 127]] x.data = [[1, 3], [2, 4], [4, 127]]
x.shape = [3, 2] x.shape = [3, 2]
...@@ -138,7 +139,7 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None): ...@@ -138,7 +139,7 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
[0.0, 0.0, ..., 0.0 ]]] # padding data [0.0, 0.0, ..., 0.0 ]]] # padding data
The input padding_idx is less than 0, it is automatically converted to padding_idx = -1 + 128 = 127 The input padding_idx is less than 0, it is automatically converted to padding_idx = -1 + 128 = 127
It will pad all-zero data when ids is 127. It will pad all-zero data when id is 127.
Args: Args:
x(Tensor): A Tensor with type int32/int64, which contains the id information. The value of the input id should x(Tensor): A Tensor with type int32/int64, which contains the id information. The value of the input id should
...@@ -151,10 +152,10 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None): ...@@ -151,10 +152,10 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
such as :ref:`api_optimizer_AdadeltaOptimizer` , :ref:`api_optimizer_AdamaxOptimizer` , such as :ref:`api_optimizer_AdadeltaOptimizer` , :ref:`api_optimizer_AdamaxOptimizer` ,
:ref:`api_optimizer_DecayedAdagradOptimizer` , :ref:`api_optimizer_FtrlOptimizer` , :ref:`api_optimizer_DecayedAdagradOptimizer` , :ref:`api_optimizer_FtrlOptimizer` ,
:ref:`api_optimizer_LambOptimizer` and :ref:`api_optimizer_LarsMomentumOptimizer` . :ref:`api_optimizer_LambOptimizer` and :ref:`api_optimizer_LarsMomentumOptimizer` .
In these cases, is_sparse must be False. Default: False. In these cases, sparse must be False. Default: False.
padding_idx(int|long|None): padding_idx needs to be in the interval [-vocab_size, vocab_size). padding_idx(int|long|None): padding_idx needs to be in the interval [-weight.shape[0], weight.shape[0]).
If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
to :math:`vocab\_size + padding\_idx` . It will output all-zero padding data whenever lookup to :math:`weight.shape[0] + padding\_idx` . It will output all-zero padding data whenever lookup
encounters :math:`padding\_idx` in id. And the padding data will not be updated while training. encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
If set None, it makes no effect to output. Default: None. If set None, it makes no effect to output. Default: None.
name(str|None): For detailed information, please refer name(str|None): For detailed information, please refer
...@@ -162,7 +163,7 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None): ...@@ -162,7 +163,7 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
None by default. None by default.
Returns: Returns:
Tensor: Embedding Tensor mapped by input. The data type is the same as :attr:`weight`. Tensor: Embedding Tensor mapped by x. The data type is the same as :attr:`weight`.
Examples: Examples:
...@@ -209,6 +210,10 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None): ...@@ -209,6 +210,10 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
weight.shape[0] + padding_idx) weight.shape[0] + padding_idx)
if padding_idx >= weight.shape[0] or padding_idx < -weight.shape[0]:
raise ValueError("padding_idx must be within [-{}, {})".format(
weight.shape[0], weight.shape[0]))
helper.append_op( helper.append_op(
type='lookup_table_v2', type='lookup_table_v2',
inputs={'Ids': x, inputs={'Ids': x,
......
...@@ -780,10 +780,10 @@ def kl_div(input, label, reduction='mean', name=None): ...@@ -780,10 +780,10 @@ def kl_div(input, label, reduction='mean', name=None):
input = np.random.uniform(-10, 10, shape).astype('float32') input = np.random.uniform(-10, 10, shape).astype('float32')
target = np.random.uniform(-10, 10, shape).astype('float32') target = np.random.uniform(-10, 10, shape).astype('float32')
# 'batchmean' reduction, loss shape will be [N] # 'batchmean' reduction, loss shape will be [1]
pred_loss = F.kl_div(paddle.to_tensor(input), pred_loss = F.kl_div(paddle.to_tensor(input),
paddle.to_tensor(target), reduction='batchmean') paddle.to_tensor(target), reduction='batchmean')
# shape=[5] # shape=[1]
# 'mean' reduction, loss shape will be [1] # 'mean' reduction, loss shape will be [1]
pred_loss = F.kl_div(paddle.to_tensor(input), pred_loss = F.kl_div(paddle.to_tensor(input),
......
...@@ -389,7 +389,7 @@ def avg_pool3d(x, ...@@ -389,7 +389,7 @@ def avg_pool3d(x,
stride=None, stride=None,
padding=0, padding=0,
ceil_mode=False, ceil_mode=False,
count_include_pad=False, count_include_pad=True,
divisor_override=None, divisor_override=None,
data_format="NCDHW", data_format="NCDHW",
name=None): name=None):
......
...@@ -249,7 +249,7 @@ def grid_sample(x, ...@@ -249,7 +249,7 @@ def grid_sample(x,
mode(str, optional): The interpolation method which can be 'bilinear' or 'nearest'. mode(str, optional): The interpolation method which can be 'bilinear' or 'nearest'.
Default: 'bilinear'. Default: 'bilinear'.
padding_mode(str, optional) The padding method used when source index padding_mode(str, optional) The padding method used when source index
is out of input images. It can be 'zeros', 'reflect' and 'border'. is out of input images. It can be 'zeros', 'reflection' and 'border'.
Default: zeros. Default: zeros.
align_corners(bool, optional): If `align_corners` is true, it will projects align_corners(bool, optional): If `align_corners` is true, it will projects
-1 and 1 to the centers of the corner pixels. Otherwise, it will -1 and 1 to the centers of the corner pixels. Otherwise, it will
...@@ -312,7 +312,7 @@ def grid_sample(x, ...@@ -312,7 +312,7 @@ def grid_sample(x,
if not isinstance(grid, Variable): if not isinstance(grid, Variable):
raise ValueError("The grid should be a Variable") raise ValueError("The grid should be a Variable")
_modes = ['bilinear', 'nearest'] _modes = ['bilinear', 'nearest']
_padding_modes = ['zeros', 'reflect', 'border'] _padding_modes = ['zeros', 'reflection', 'border']
if mode not in _modes: if mode not in _modes:
raise ValueError( raise ValueError(
"The mode of grid sample function should be in {}, but got: {}". "The mode of grid sample function should be in {}, but got: {}".
......
...@@ -1564,22 +1564,18 @@ class CosineSimilarity(layers.Layer): ...@@ -1564,22 +1564,18 @@ class CosineSimilarity(layers.Layer):
class Embedding(layers.Layer): class Embedding(layers.Layer):
""" """
:alias_main: paddle.nn.Embedding
:alias: paddle.nn.Embedding,paddle.nn.layer.Embedding,paddle.nn.layer.common.Embedding
:old_api: paddle.fluid.dygraph.Embedding
**Embedding Layer** **Embedding Layer**
This interface is used to construct a callable object of the ``Embedding`` class. This interface is used to construct a callable object of the ``Embedding`` class.
For specific usage, refer to code examples. It implements the function of the Embedding Layer. For specific usage, refer to code examples. It implements the function of the Embedding Layer.
This layer is used to lookup embeddings vector of ids provided by :attr:`input` . This layer is used to lookup embeddings vector of ids provided by :attr:`x` .
It automatically constructs a 2D embedding matrix based on the It automatically constructs a 2D embedding matrix based on the
input :attr:`size` (vocab_size, emb_size) and :attr:`dtype` . input :attr:`num_embeddings` and :attr:`embedding_dim`.
The shape of output Tensor is generated by appending an emb_size dimension to the The shape of output Tensor is generated by appending an emb_size dimension to the
last dimension of the input Tensor shape. last dimension of the input Tensor shape.
**Note:** The id in :attr:`input` must satisfy :math:`0 <= id < size[0]` , **Note:** The id in :attr:`x` must satisfy :math:`0 <= id < num_embeddings` ,
otherwise the program will throw an exception and exit. otherwise the program will throw an exception and exit.
.. code-block:: text .. code-block:: text
...@@ -1607,7 +1603,7 @@ class Embedding(layers.Layer): ...@@ -1607,7 +1603,7 @@ class Embedding(layers.Layer):
num_embeddings (int): Just one element which indicate the size num_embeddings (int): Just one element which indicate the size
of the dictionary of embeddings. of the dictionary of embeddings.
embedding_dim: Just one element which indicate the size of each embedding vector respectively. embedding_dim: Just one element which indicate the size of each embedding vector respectively.
padding_idx(int|long|None): padding_idx needs to be in the interval [-vocab_size, vocab_size). padding_idx(int|long|None): padding_idx needs to be in the interval [-num_embeddings, num_embeddings).
If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
to :math:`vocab\_size + padding\_idx` . It will output all-zero padding data whenever lookup to :math:`vocab\_size + padding\_idx` . It will output all-zero padding data whenever lookup
encounters :math:`padding\_idx` in id. And the padding data will not be updated while training. encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
...@@ -1618,13 +1614,13 @@ class Embedding(layers.Layer): ...@@ -1618,13 +1614,13 @@ class Embedding(layers.Layer):
such as :ref:`api_optimizer_AdadeltaOptimizer` , :ref:`api_optimizer_AdamaxOptimizer` , such as :ref:`api_optimizer_AdadeltaOptimizer` , :ref:`api_optimizer_AdamaxOptimizer` ,
:ref:`api_optimizer_DecayedAdagradOptimizer` , :ref:`api_optimizer_FtrlOptimizer` , :ref:`api_optimizer_DecayedAdagradOptimizer` , :ref:`api_optimizer_FtrlOptimizer` ,
:ref:`api_optimizer_LambOptimizer` and :ref:`api_optimizer_LarsMomentumOptimizer` . :ref:`api_optimizer_LambOptimizer` and :ref:`api_optimizer_LarsMomentumOptimizer` .
In these cases, is_sparse must be False. Default: False. In these cases, sparse must be False. Default: False.
weight_attr(ParamAttr): To specify the weight parameter property. Default: None, which means the weight_attr(ParamAttr): To specify the weight parameter property. Default: None, which means the
default weight parameter property is used. See usage for details in :ref:`api_fluid_ParamAttr` . In addition, default weight parameter property is used. See usage for details in :ref:`api_ParamAttr` . In addition,
user-defined or pre-trained word vectors can be loaded with the :attr:`weight_attr` parameter. user-defined or pre-trained word vectors can be loaded with the :attr:`weight_attr` parameter.
The local word vector needs to be transformed into numpy format, and the shape of local word The local word vector needs to be transformed into numpy format, and the shape of local word
vector should be consistent with :attr:`size` . Then :ref:`api_fluid_initializer_NumpyArrayInitializer` vector should be consistent with :attr:`num_embeddings` . Then :ref:`api_initializer_NumpyArrayInitializer`
is used to load custom or pre-trained word vectors. See code example 2 for details. is used to load custom or pre-trained word vectors. See code example for details.
name(str|None): For detailed information, please refer name(str|None): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and to :ref:`api_guide_Name`. Usually name is no need to set and
None by default. None by default.
...@@ -1640,19 +1636,33 @@ class Embedding(layers.Layer): ...@@ -1640,19 +1636,33 @@ class Embedding(layers.Layer):
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.nn as nn
import numpy as np import numpy as np
paddle.disable_static()
# example 1 x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
inp_word = np.array([[2, 3, 5], [4, 2, 1]]).astype('int64') y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
inp_word.shape # [2, 3] paddle.disable_static(paddle.CPUPlace())
dict_size = 20 x = paddle.to_tensor(x_data, stop_gradient=False)
y = paddle.to_tensor(y_data, stop_gradient=False)
embedding = paddle.nn.Embedding(10, 3, sparse=True)
w0=np.full(shape=(10, 3), fill_value=2).astype(np.float32)
embedding.weight.set_value(w0)
adam = paddle.optimizer.Adam(parameters=[embedding.weight], learning_rate=0.01)
adam.clear_grad()
# weight.shape = [10, 3]
# x.data = [[3],[4],[5]]
# x.shape = [3, 1]
# out.data = [[2,2,2], [2,2,2], [2,2,2]]
# out.shape = [3, 1, 3]
out=embedding(x)
out.backward()
adam.step()
emb = nn.Embedding(
dict_size,
32,
sparse=False)
""" """
def __init__(self, def __init__(self,
...@@ -1669,13 +1679,24 @@ class Embedding(layers.Layer): ...@@ -1669,13 +1679,24 @@ class Embedding(layers.Layer):
self._is_distributed = False self._is_distributed = False
self._padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( self._padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
num_embeddings + padding_idx) num_embeddings + padding_idx)
if self._num_embeddings <= 0:
raise ValueError("num_embeddings must be gather than 0")
if self._embedding_dim <= 0:
raise ValueError("embedding_dim must be gather than 0")
if self._padding_idx >= num_embeddings or self._padding_idx < -num_embeddings:
raise ValueError("padding_idx must be within [-{}, {})".format(
num_embeddings, num_embeddings))
self._dtype = self._helper.get_default_dtype() self._dtype = self._helper.get_default_dtype()
self._size = [self._num_embeddings, self._embedding_dim] self._size = [self._num_embeddings, self._embedding_dim]
self._weight_attr = weight_attr self._weight_attr = weight_attr
self._remote_prefetch = False self._remote_prefetch = False
self._name = name self._name = name
self._weight = self.create_parameter( self.weight = self.create_parameter(
attr=self._weight_attr, attr=self._weight_attr,
shape=self._size, shape=self._size,
dtype=self._dtype, dtype=self._dtype,
...@@ -1684,7 +1705,7 @@ class Embedding(layers.Layer): ...@@ -1684,7 +1705,7 @@ class Embedding(layers.Layer):
def forward(self, x): def forward(self, x):
return F.embedding( return F.embedding(
x, x,
weight=self._weight, weight=self.weight,
padding_idx=self._padding_idx, padding_idx=self._padding_idx,
sparse=self._sparse, sparse=self._sparse,
name=self._name) name=self._name)
...@@ -627,9 +627,12 @@ class KLDivLoss(fluid.dygraph.Layer): ...@@ -627,9 +627,12 @@ class KLDivLoss(fluid.dygraph.Layer):
$$l(x, y) = y * (\log(y) - x)$$ $$l(x, y) = y * (\log(y) - x)$$
Parameters: Parameters:
reduction (str, optional): Indicate how to average the loss, reduction (str, optional): Indicate how to average the loss,
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; If `reduction` is ``'mean'``, the reduced mean loss is returned;
If `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
if `reduction` is ``'sum'``, the reduced sum loss is returned;
if `reduction` is ``'none'``, no reduction will be applied.
Default is ``'mean'``. Default is ``'mean'``.
Shape: Shape:
...@@ -654,11 +657,11 @@ class KLDivLoss(fluid.dygraph.Layer): ...@@ -654,11 +657,11 @@ class KLDivLoss(fluid.dygraph.Layer):
x = np.random.uniform(-10, 10, shape).astype('float32') x = np.random.uniform(-10, 10, shape).astype('float32')
target = np.random.uniform(-10, 10, shape).astype('float32') target = np.random.uniform(-10, 10, shape).astype('float32')
# 'batchmean' reduction, loss shape will be [N] # 'batchmean' reduction, loss shape will be [1]
kldiv_criterion = nn.KLDivLoss(reduction='batchmean') kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
pred_loss = kldiv_criterion(paddle.to_tensor(x), pred_loss = kldiv_criterion(paddle.to_tensor(x),
paddle.to_tensor(target)) paddle.to_tensor(target))
# shape=[5] # shape=[1]
# 'mean' reduction, loss shape will be [1] # 'mean' reduction, loss shape will be [1]
kldiv_criterion = nn.KLDivLoss(reduction='mean') kldiv_criterion = nn.KLDivLoss(reduction='mean')
...@@ -684,7 +687,7 @@ class KLDivLoss(fluid.dygraph.Layer): ...@@ -684,7 +687,7 @@ class KLDivLoss(fluid.dygraph.Layer):
self.reduction = reduction self.reduction = reduction
def forward(self, input, label): def forward(self, input, label):
out = paddle.nn.functional.kl_div(input, label, self.reduction) out = F.kl_div(input, label, self.reduction)
return out return out
......
...@@ -183,12 +183,13 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None): ...@@ -183,12 +183,13 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
x (Tensor): The input tensor could be N-D tensor, and the input data x (Tensor): The input tensor could be N-D tensor, and the input data
type could be float32 or float64. type could be float32 or float64.
p (float|string, optional): Order of the norm. Supported values are `fro`, `0`, `1`, `2`, p (float|string, optional): Order of the norm. Supported values are `fro`, `0`, `1`, `2`,
`inf`,`-inf` and any positive real number yielding the corresponding p-norm. `inf`, `-inf` and any positive real number yielding the corresponding p-norm. Not supported: ord < 0 and nuclear norm.
Not supported: ord < 0, nuclear norm. Default value is `fro`.
axis (int|list|tuple, optional): The axis on which to apply norm operation. If axis is int axis (int|list|tuple, optional): The axis on which to apply norm operation. If axis is int
or list(int)/tuple(int) with only one element, the vector norm is computed over the axis. or list(int)/tuple(int) with only one element, the vector norm is computed over the axis.
If `axis < 0`, the dimension to norm operation is rank(input) + axis. If `axis < 0`, the dimension to norm operation is rank(input) + axis.
If axis is a list(int)/tuple(int) with two elements, the matrix norm is computed over the axis. If axis is a list(int)/tuple(int) with two elements, the matrix norm is computed over the axis.
Default value is `None`.
keepdim (bool, optional): Whether to reserve the reduced dimension in the keepdim (bool, optional): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have fewer dimension output Tensor. The result tensor will have fewer dimension
than the :attr:`input` unless :attr:`keepdim` is true, default than the :attr:`input` unless :attr:`keepdim` is true, default
...@@ -197,13 +198,9 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None): ...@@ -197,13 +198,9 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
user to set this property. For more information, please refer to :ref:`api_guide_Name`. user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
Variable: Tensor, results of norm operation on the specified axis of input tensor, Tensor: results of norm operation on the specified axis of input tensor,
its data type is the same as the input Tensor's. its data type is the same as the input Tensor's.
Raises:
TypeError, if out data type is different with the input data type.
ValueError, If `p` or `axis` is invalid.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -256,15 +253,13 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None): ...@@ -256,15 +253,13 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
"The dim of frobenius norm op should be None or two elements list!" "The dim of frobenius norm op should be None or two elements list!"
) )
if in_dygraph_mode(): if in_dygraph_mode():
if dim is None: dim = [-1] if dim is None:
return core.ops.frobenius_norm(input, 'dim', dim, 'keepdim', return core.ops.frobenius_norm(input, 'keep_dim', keepdim,
keepdim) 'reduce_all', True)
attrs = { return core.ops.frobenius_norm(input, 'dim', dim, 'keep_dim',
'dim': dim if dim != None else [-2, -1], keepdim, 'reduce_all', False)
'keep_dim': keepdim, attrs = {'dim': dim, 'keep_dim': keepdim, 'reduce_all': False}
'reduce_all': False if dim is None:
}
if len(attrs['dim']) == len(input.shape):
attrs['reduce_all'] = True attrs['reduce_all'] = True
check_variable_and_dtype(input, 'input', ['float32', 'float64'], check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'frobenius_norm') 'frobenius_norm')
...@@ -351,42 +346,6 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None): ...@@ -351,42 +346,6 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
return reduce_out return reduce_out
def p0_matrix_norm(input, porder=0., axis=axis, keepdim=False, name=None):
block = LayerHelper('norm', **locals())
out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
cast_out = block.create_variable_for_type_inference(dtype=bool)
block.append_op(
type='cast',
inputs={'X': input},
outputs={'Out': cast_out},
attrs={
'in_dtype': input.dtype,
'out_dtype': int(core.VarDesc.VarType.BOOL)
})
cast_out2 = block.create_variable_for_type_inference(dtype=bool)
block.append_op(
type='cast',
inputs={'X': cast_out},
outputs={'Out': cast_out2},
attrs={
'in_dtype': cast_out.dtype,
'out_dtype': int(core.VarDesc.VarType.FP32)
})
sum_out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
block.append_op(
type='reduce_sum',
inputs={'X': cast_out2},
outputs={'Out': sum_out},
attrs={
'dim': axis,
'keep_dim': keepdim,
'reduce_all': True if axis is None else False
})
return sum_out
def p_matrix_norm(input, porder=1., axis=axis, keepdim=False, name=None): def p_matrix_norm(input, porder=1., axis=axis, keepdim=False, name=None):
block = LayerHelper('norm', **locals()) block = LayerHelper('norm', **locals())
out = block.create_variable_for_type_inference( out = block.create_variable_for_type_inference(
...@@ -448,7 +407,20 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None): ...@@ -448,7 +407,20 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
#calculate vector norm, where axis is int or list with only one integer #calculate vector norm, where axis is int or list with only one integer
if isinstance(axis, int): if isinstance(axis, int):
if isinstance(p, (int, float)): if isinstance(p, str):
if p == "fro":
return vector_norm(
x,
porder=2,
axis=axis,
keepdim=keepdim,
asvector=False,
name=name)
else:
raise ValueError(
"only valid string values are 'fro', found {}".format(p))
elif isinstance(p, (int, float)):
return vector_norm( return vector_norm(
x, x,
axis=axis, axis=axis,
...@@ -464,10 +436,12 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None): ...@@ -464,10 +436,12 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
elif isinstance(axis, list) and len(axis) == 2: elif isinstance(axis, list) and len(axis) == 2:
if p == "fro": if p == "fro":
return frobenius_norm(x, dim=axis, keepdim=keepdim, name=name) return frobenius_norm(x, dim=axis, keepdim=keepdim, name=name)
elif p == 0:
return p0_matrix_norm(x, axis=axis, keepdim=keepdim, name=name)
elif p == np.inf or p == -np.inf: elif p == np.inf or p == -np.inf:
return inf_norm(x, porder=p, axis=axis, keepdim=keepdim, name=name) return inf_norm(x, porder=p, axis=axis, keepdim=keepdim, name=name)
elif p == 0:
raise ValueError(
"just suport axis type int or list (length of list <=1) if p = 0, found {}".
format(axis))
else: else:
return p_matrix_norm( return p_matrix_norm(
x, porder=p, axis=axis, keepdim=keepdim, name=name) x, porder=p, axis=axis, keepdim=keepdim, name=name)
......
...@@ -523,6 +523,24 @@ class TestModelFunction(unittest.TestCase): ...@@ -523,6 +523,24 @@ class TestModelFunction(unittest.TestCase):
model.summary(input_size=[(20)]) model.summary(input_size=[(20)])
model.summary(input_size=(20), batch_size=2) model.summary(input_size=(20), batch_size=2)
def test_summary_nlp(self):
paddle.enable_static()
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
paddle.summary(nlp_net, (1, 2))
def test_summary_error(self):
with self.assertRaises(TypeError):
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
paddle.summary(nlp_net, (1, '2'))
with self.assertRaises(ValueError):
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
paddle.summary(nlp_net, (-1, -1))
paddle.disable_static()
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
paddle.summary(nlp_net, (1, 2))
def test_export_deploy_model(self): def test_export_deploy_model(self):
for dynamic in [True, False]: for dynamic in [True, False]:
fluid.enable_dygraph() if dynamic else None fluid.enable_dygraph() if dynamic else None
......
PyGithub
coverage
pycrypto
mock
...@@ -90,12 +90,12 @@ def get_info_file_lines(info_file, diff_file): ...@@ -90,12 +90,12 @@ def get_info_file_lines(info_file, diff_file):
continue continue
elif line.startswith('LF:'): elif line.startswith('LF:'):
print 'LF:{}'.format(current_lf) print('LF:{}'.format(current_lf))
continue continue
elif line.startswith('LH:'): elif line.startswith('LH:'):
print 'LH:{}'.format(current_lh) print('LH:{}'.format(current_lh))
continue continue
......
...@@ -40,7 +40,7 @@ def filter_by(list_file, max_rate): ...@@ -40,7 +40,7 @@ def filter_by(list_file, max_rate):
except: except:
pass pass
print name, rate print(name, rate)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -33,7 +33,7 @@ def get_lines(info_file): ...@@ -33,7 +33,7 @@ def get_lines(info_file):
hits += 1 hits += 1
if total == 0: if total == 0:
print 'no data found' print('no data found')
exit() exit()
return hits / total return hits / total
...@@ -47,17 +47,17 @@ if __name__ == '__main__': ...@@ -47,17 +47,17 @@ if __name__ == '__main__':
expected = float(sys.argv[2]) expected = float(sys.argv[2])
if not os.path.isfile(info_file): if not os.path.isfile(info_file):
print 'info file {} is not exists, ignored'.format(info_file) print('info file {} is not exists, ignored'.format(info_file))
exit() exit()
actual = get_lines(info_file) actual = get_lines(info_file)
actual = round(actual, 3) actual = round(actual, 3)
if actual < expected: if actual < expected:
print 'expected >= {} %, actual {} %, failed'.format( print('expected >= {} %, actual {} %, failed'.format(
round(expected * 100, 1), round(actual * 100, 1)) round(expected * 100, 1), round(actual * 100, 1)))
exit(1) exit(1)
print 'expected >= {} %, actual {} %, passed'.format( print('expected >= {} %, actual {} %, passed'.format(
round(expected * 100, 1), round(actual * 100, 1)) round(expected * 100, 1), round(actual * 100, 1)))
...@@ -5,7 +5,7 @@ set -xe ...@@ -5,7 +5,7 @@ set -xe
PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../../" && pwd )" PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../../" && pwd )"
# install lcov # install lcov
curl -o /lcov-1.14.tar.gz -x "" -s https://paddle-ci.gz.bcebos.com/coverage/lcov-1.14.tar.gz curl -o /lcov-1.14.tar.gz -x "" -s https://paddle-ci.gz.bcebos.com/coverage/lcov-1.14.tar.gz || exit 101
tar -xf /lcov-1.14.tar.gz -C / tar -xf /lcov-1.14.tar.gz -C /
cd /lcov-1.14 cd /lcov-1.14
make install make install
...@@ -14,7 +14,7 @@ make install ...@@ -14,7 +14,7 @@ make install
cd /paddle/build cd /paddle/build
python ${PADDLE_ROOT}/tools/coverage/gcda_clean.py ${GIT_PR_ID} python3 ${PADDLE_ROOT}/tools/coverage/gcda_clean.py ${GIT_PR_ID}
lcov --capture -d ./ -o coverage.info --rc lcov_branch_coverage=0 lcov --capture -d ./ -o coverage.info --rc lcov_branch_coverage=0
...@@ -53,9 +53,9 @@ gen_full_html_report || true ...@@ -53,9 +53,9 @@ gen_full_html_report || true
function gen_diff_html_report() { function gen_diff_html_report() {
if [ "${GIT_PR_ID}" != "" ]; then if [ "${GIT_PR_ID}" != "" ]; then
COVERAGE_DIFF_PATTERN="`python ${PADDLE_ROOT}/tools/coverage/pull_request.py files ${GIT_PR_ID}`" COVERAGE_DIFF_PATTERN="`python3 ${PADDLE_ROOT}/tools/coverage/pull_request.py files ${GIT_PR_ID}`"
python ${PADDLE_ROOT}/tools/coverage/pull_request.py diff ${GIT_PR_ID} > git-diff.out python3 ${PADDLE_ROOT}/tools/coverage/pull_request.py diff ${GIT_PR_ID} > git-diff.out
fi fi
lcov --extract coverage-full.info \ lcov --extract coverage-full.info \
...@@ -63,7 +63,7 @@ function gen_diff_html_report() { ...@@ -63,7 +63,7 @@ function gen_diff_html_report() {
-o coverage-diff.info \ -o coverage-diff.info \
--rc lcov_branch_coverage=0 --rc lcov_branch_coverage=0
python ${PADDLE_ROOT}/tools/coverage/coverage_diff.py coverage-diff.info git-diff.out > coverage-diff.tmp python3 ${PADDLE_ROOT}/tools/coverage/coverage_diff.py coverage-diff.info git-diff.out > coverage-diff.tmp
mv -f coverage-diff.tmp coverage-diff.info mv -f coverage-diff.tmp coverage-diff.info
...@@ -82,7 +82,7 @@ set -x ...@@ -82,7 +82,7 @@ set -x
coverage xml -i -o python-coverage.xml coverage xml -i -o python-coverage.xml
python ${PADDLE_ROOT}/tools/coverage/python_coverage.py > python-coverage.info python3 ${PADDLE_ROOT}/tools/coverage/python_coverage.py > python-coverage.info
# python full html report # python full html report
# #
...@@ -143,5 +143,6 @@ echo "Assert Python Diff Coverage" ...@@ -143,5 +143,6 @@ echo "Assert Python Diff Coverage"
python ${PADDLE_ROOT}/tools/coverage/coverage_lines.py python-coverage-diff.info 0.9 || PYTHON_COVERAGE_LINES_ASSERT=1 python ${PADDLE_ROOT}/tools/coverage/coverage_lines.py python-coverage-diff.info 0.9 || PYTHON_COVERAGE_LINES_ASSERT=1
if [ "$COVERAGE_LINES_ASSERT" = "1" ] || [ "$PYTHON_COVERAGE_LINES_ASSERT" = "1" ]; then if [ "$COVERAGE_LINES_ASSERT" = "1" ] || [ "$PYTHON_COVERAGE_LINES_ASSERT" = "1" ]; then
echo "exit 9" > /tmp/paddle_coverage.result
exit 9 exit 9
fi fi
...@@ -40,7 +40,7 @@ def get_files(args): ...@@ -40,7 +40,7 @@ def get_files(args):
pull = get_pull(args.pull_id) pull = get_pull(args.pull_id)
for file in pull.get_files(): for file in pull.get_files():
print '/paddle/{}'.format(file.filename) print('/paddle/{}'.format(file.filename))
def diff(args): def diff(args):
...@@ -55,8 +55,8 @@ def diff(args): ...@@ -55,8 +55,8 @@ def diff(args):
pull = get_pull(args.pull_id) pull = get_pull(args.pull_id)
for file in pull.get_files(): for file in pull.get_files():
print '+++ {}'.format(file.filename) print('+++ {}'.format(file.filename))
print file.patch print(file.patch)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -12,10 +12,7 @@ root = tree.getroot() ...@@ -12,10 +12,7 @@ root = tree.getroot()
sources = root.findall('sources/source') sources = root.findall('sources/source')
if len(sources) > 1: source = sources[-1].text
exit(1)
source = sources[0].text
for clazz in root.findall('packages/package/classes/class'): for clazz in root.findall('packages/package/classes/class'):
clazz_filename = clazz.attrib.get('filename') clazz_filename = clazz.attrib.get('filename')
...@@ -28,8 +25,8 @@ for clazz in root.findall('packages/package/classes/class'): ...@@ -28,8 +25,8 @@ for clazz in root.findall('packages/package/classes/class'):
if not path.exists(clazz_filename): if not path.exists(clazz_filename):
continue continue
print 'TN:' print('TN:')
print 'SF:{}'.format(clazz_filename) print('SF:{}'.format(clazz_filename))
branch_index = 0 branch_index = 0
...@@ -50,16 +47,16 @@ for clazz in root.findall('packages/package/classes/class'): ...@@ -50,16 +47,16 @@ for clazz in root.findall('packages/package/classes/class'):
taken = int(taken) taken = int(taken)
for _ in range(taken): for _ in range(taken):
print 'BRDA:{},{},{},{}'.format(line_number, 0, branch_index, print('BRDA:{},{},{},{}'.format(line_number, 0, branch_index,
line_hits) line_hits))
branch_index += 1 branch_index += 1
if line_missing_branches: if line_missing_branches:
for missing_branch in line_missing_branches.split(','): for missing_branch in line_missing_branches.split(','):
print 'BRDA:{},{},{},{}'.format(line_number, 0, print('BRDA:{},{},{},{}'.format(line_number, 0,
branch_index, 0) branch_index, 0))
branch_index += 1 branch_index += 1
print 'DA:{},{}'.format(line_number, line_hits) print('DA:{},{}'.format(line_number, line_hits))
print 'end_of_record' print('end_of_record')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册