提交 5ccab2dc 编写于 作者: C chengduoZH

remove conflict

......@@ -2,7 +2,7 @@
|---|---|
| backyes | Yan-Fei Wang |
| beckett1124 | Bin Qi |
| Canpio | Jia-Yi Feng |
| JiayiFeng | Jia-Yi Feng |
| chengxiaohua1105 | Xiao-Hua Cheng |
| cxwangyi, yiwangbaidu, wangkuiyi | Yi Wang |
| cxysteven | Xing-Yi Cheng |
......
......@@ -82,7 +82,7 @@ language = 'zh_CN'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['_build', '**/*_en*', '*_en*']
exclude_patterns = ['_build', '**/*_en*', '*_en*', 'api/*']
# The reST default role (used for this markup: `text`) to use for all
# documents.
......
......@@ -82,7 +82,7 @@ language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['_build', '**/*_cn*', '*_cn*']
exclude_patterns = ['_build', '**/*_cn*', '*_cn*', 'api/*']
# The reST default role (used for this markup: `text`) to use for all
# documents.
......
......@@ -11,7 +11,6 @@ if(MOBILE_INFERENCE)
else()
add_subdirectory(pserver)
add_subdirectory(trainer)
add_subdirectory(string)
add_subdirectory(scripts)
if(WITH_C_API)
......
......@@ -4,3 +4,4 @@ add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind)
add_subdirectory(inference)
add_subdirectory(string)
......@@ -314,5 +314,15 @@ DDim stride(const DDim& ddim) {
}
return framework::make_ddim(strides);
}
// Computes the "stride numel" of `ddim`: entry i is the product of
// ddim[i], ddim[i + 1], ..., ddim[rank - 1], i.e. the number of elements
// spanned by axis i and every axis after it.
//
// For example, dims [4, 20, 100] give [8000, 2000, 100].
DDim stride_numel(const framework::DDim& ddim) {
  const int rank = ddim.size();
  std::vector<int64_t> numel(rank);
  // Walk from the innermost axis outwards, accumulating the product.
  int64_t running = 1;
  for (int axis = rank - 1; axis >= 0; --axis) {
    running *= ddim[axis];
    numel[axis] = running;
  }
  return framework::make_ddim(numel);
}
} // namespace framework
} // namespace paddle
......@@ -125,6 +125,8 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims);
DDim flatten_to_1d(const DDim& src);
DDim stride(const DDim& ddim);
DDim stride_numel(const DDim& ddim);
} // namespace framework
} // namespace paddle
......
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/string/piece.h"
#include "paddle/fluid/string/piece.h"
namespace paddle {
namespace framework {
......
......@@ -37,9 +37,8 @@ class Vector {
// Fill vector with value. The vector size is `count`.
explicit Vector(size_t count, const T& value = T()) {
if (count == 0) {
InitEmpty();
} else {
InitEmpty();
if (count != 0) {
resize(count);
T* ptr = begin();
for (size_t i = 0; i < count; ++i) {
......@@ -122,6 +121,10 @@ class Vector {
const T* begin() const { return &this->operator[](0); }
const T* end() const { return &this->operator[](size()); }
const T* cbegin() const { return begin(); }
const T* cend() const { return end(); }
const T& back() const {
auto it = end();
--it;
......@@ -244,7 +247,9 @@ class Vector {
bool operator==(const Vector<T>& other) const {
if (size() != other.size()) return false;
for (auto it1 = begin(), it2 = other.begin(); it1 < end(); ++it1, ++it2) {
auto it1 = cbegin();
auto it2 = other.cbegin();
for (; it1 < cend(); ++it1, ++it2) {
if (*it1 != *it2) {
return false;
}
......
......@@ -26,10 +26,10 @@ TEST(mixed_vector, CPU_VECTOR) {
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10);
ASSERT_EQ(tmp.size(), 10UL);
vec<int> tmp2;
tmp2 = tmp;
ASSERT_EQ(tmp2.size(), 10);
ASSERT_EQ(tmp2.size(), 10UL);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp2[i], i);
ASSERT_EQ(tmp2[i], tmp[i]);
......@@ -58,7 +58,7 @@ TEST(mixed_vector, GPU_VECTOR) {
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10);
ASSERT_EQ(tmp.size(), 10UL);
paddle::platform::CUDAPlace gpu(0);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu)>>>(tmp.MutableData(gpu));
......@@ -79,7 +79,7 @@ TEST(mixed_vector, MultiGPU) {
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10);
ASSERT_EQ(tmp.size(), 10UL);
paddle::platform::CUDAPlace gpu0(0);
paddle::platform::SetDeviceId(0);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu0)>>>(tmp.MutableData(gpu0));
......@@ -91,3 +91,10 @@ TEST(mixed_vector, MultiGPU) {
ASSERT_EQ(tmp[i], i * 100);
}
}
// Exercises the Vector(count, value) constructor: the vector must hold exactly
// `count` elements and every one of them must equal `value`.
TEST(mixed_vector, InitWithCount) {
  paddle::framework::Vector<int> filled(10, 10);
  // Check the size first so a short vector fails here instead of through an
  // out-of-range read in the loop below.
  ASSERT_EQ(filled.size(), 10UL);
  for (int i = 0; i < 10; ++i) {
    ASSERT_EQ(filled[i], 10);
  }
}
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <mutex> // for call_once
#include "glog/logging.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/string/printf.h"
#include "paddle/fluid/string/printf.h"
DEFINE_bool(benchmark, false,
"Doing memory benchmark. It will make deleting scope synchronized, "
......
......@@ -28,17 +28,18 @@ class ConcatKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
const size_t n = ins.size();
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
auto out_stride = framework::stride_numel(out->dims());
size_t output_offset = 0;
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride(out->dims());
for (size_t i = 0; i < n; i++) {
auto& in = ins[i];
auto axis_dim = in->dims()[axis];
auto in_stride = framework::stride(in->dims());
StridedMemcpy<T>(ctx.device_context(), in->data<T>(), in_stride,
in->dims(), out_stride, out->data<T>() + output_offset);
output_offset += axis_dim * in_stride[axis];
for (auto* in : ins) {
auto in_stride = framework::stride_numel(in->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
out->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride);
output_offset += in_stride[axis];
}
}
};
......@@ -50,17 +51,16 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
const size_t n = outs.size();
size_t input_offset = 0;
auto in_stride = framework::stride(in->dims());
for (size_t i = 0; i < n; i++) {
auto& out = outs[i];
auto in_stride = framework::stride_numel(in->dims());
for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
size_t axis_dim = out->dims()[axis];
auto out_stride = framework::stride(out->dims());
StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset,
in_stride, out->dims(), out_stride, out->data<T>());
input_offset += axis_dim * in_stride[axis];
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride);
input_offset += out_stride[axis];
}
}
};
......
......@@ -27,7 +27,7 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
......@@ -101,11 +101,15 @@ class ListenAndServOp : public framework::OperatorBase {
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars;
while (!exit_flag) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
size_t recv_var_cnt = 0;
size_t update_param_cnt = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in) {
const detail::MessageWithName &v = rpc_service_->Get();
......@@ -126,13 +130,14 @@ class ListenAndServOp : public framework::OperatorBase {
std::string param_var_name;
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
update_param_cnt++;
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;
} else {
LOG(ERROR) << "grad has no paired param:" << grad_var_name;
VLOG(3) << "received variable: " << grad_var_name
<< " no need to update param";
}
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;
if (fan_in > 1) {
if (fan_in > 1 && !param_var_name.empty()) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
auto *var = recv_scope.FindVar(grad_var_name);
......@@ -141,23 +146,35 @@ class ListenAndServOp : public framework::OperatorBase {
PADDLE_THROW("Can not find server side var");
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
if (var->IsType<framework::SelectedRows>()) {
sparse_vars.push_back(var);
}
}
}
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
// TODO(Yancey1989): merge SelectedRows variables here
if (exit_flag) {
rpc_service_->ShutDown();
}
VLOG(3) << "run optimize graph...";
try {
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
// Reset the received sparse variables: the sum operator does not sum input
// sparse variables whose rows are empty, so stale ones are skipped in the
// next mini-batch.
// TODO(Yancey1989): move the reset action into an operator; we should not
// have any hidden logic in the operator.
for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(recv_var_cnt);
rpc_service_->WaitClientGet(update_param_cnt);
grads_counter_.clear();
sparse_vars.clear();
} // while(true)
}
......
......@@ -38,22 +38,22 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
auto box_dims = ctx->GetInputDim("BBoxes");
auto score_dims = ctx->GetInputDim("Scores");
PADDLE_ENFORCE_EQ(box_dims.size(), 2,
"The rank of Input(BBoxes) must be 2.");
PADDLE_ENFORCE_EQ(box_dims.size(), 3,
"The rank of Input(BBoxes) must be 3.");
PADDLE_ENFORCE_EQ(score_dims.size(), 3,
"The rank of Input(Scores) must be 3.");
PADDLE_ENFORCE_EQ(box_dims[1], 4,
PADDLE_ENFORCE_EQ(box_dims[2], 4,
"The 2nd dimension of Input(BBoxes) must be 4, "
"represents the layout of coordinate "
"[xmin, ymin, xmax, ymax]");
PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2],
PADDLE_ENFORCE_EQ(box_dims[1], score_dims[2],
"The 1st dimensiong of Input(BBoxes) must be equal to "
"3rd dimension of Input(Scores), which represents the "
"predicted bboxes.");
// Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel.
ctx->SetOutputDim("Out", {box_dims[0], 6});
ctx->SetOutputDim("Out", {box_dims[1], 6});
}
protected:
......@@ -260,15 +260,20 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
int64_t batch_size = score_dims[0];
int64_t class_num = score_dims[1];
int64_t predict_dim = score_dims[2];
int64_t box_dim = boxes->dims()[2];
std::vector<std::map<int, std::vector<int>>> all_indices;
std::vector<size_t> batch_starts = {0};
for (int64_t i = 0; i < batch_size; ++i) {
Tensor ins_score = scores->Slice(i, i + 1);
ins_score.Resize({class_num, predict_dim});
Tensor ins_boxes = boxes->Slice(i, i + 1);
ins_boxes.Resize({predict_dim, box_dim});
std::map<int, std::vector<int>> indices;
int num_nmsed_out = 0;
MultiClassNMS(ctx, ins_score, *boxes, indices, num_nmsed_out);
MultiClassNMS(ctx, ins_score, ins_boxes, indices, num_nmsed_out);
all_indices.push_back(indices);
batch_starts.push_back(batch_starts.back() + num_nmsed_out);
}
......@@ -282,11 +287,15 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
for (int64_t i = 0; i < batch_size; ++i) {
Tensor ins_score = scores->Slice(i, i + 1);
ins_score.Resize({class_num, predict_dim});
Tensor ins_boxes = boxes->Slice(i, i + 1);
ins_boxes.Resize({predict_dim, box_dim});
int64_t s = batch_starts[i];
int64_t e = batch_starts[i + 1];
if (e > s) {
Tensor out = outs->Slice(s, e);
MultiClassOutput(ins_score, *boxes, all_indices[i], &out);
MultiClassOutput(ins_score, ins_boxes, all_indices[i], &out);
}
}
}
......@@ -303,9 +312,9 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
MultiClassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("BBoxes",
"(Tensor) A 2-D Tensor with shape [M, 4] represents the "
"predicted locations of M bounding bboxes. Each bounding box "
"has four coordinate values and the layout is "
"(Tensor) A 3-D Tensor with shape [N, M, 4] represents the "
"predicted locations of M bounding bboxes, N is the batch size. "
"Each bounding box has four coordinate values and the layout is "
"[xmin, ymin, xmax, ymax].");
AddInput("Scores",
"(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
......
......@@ -24,6 +24,22 @@ limitations under the License. */
namespace paddle {
namespace operators {
// Tells whether the variable named `varname` in `scope` holds initialized
// data.
//
// The variable must exist in `scope` (enforced). For a LoDTensor the result is
// the tensor's IsInitialized(); for a SelectedRows it is IsInitialized() of
// its value tensor. Any other variable type raises via PADDLE_THROW.
static bool IsVariableInitialized(const framework::Scope& scope,
                                  const std::string& varname) {
  auto* found = scope.FindVar(varname);
  PADDLE_ENFORCE_NOT_NULL(found,
                          "Can not find variable '%s' in the send side.",
                          varname);

  if (found->IsType<framework::LoDTensor>()) {
    const auto& tensor = found->Get<framework::LoDTensor>();
    return tensor.IsInitialized();
  }
  if (found->IsType<framework::SelectedRows>()) {
    const auto& rows = found->Get<framework::SelectedRows>();
    return rows.value().IsInitialized();
  }

  PADDLE_THROW(
      "Variable type in send side should be in "
      "[LodTensor, SelectedRows]");
  // Not reached when PADDLE_THROW raises; kept so every path returns a value.
  return false;
}
class SendOp : public framework::OperatorBase {
public:
......@@ -51,8 +67,12 @@ class SendOp : public framework::OperatorBase {
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
if (IsVariableInitialized(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
PADDLE_ENFORCE(rpc_client->Wait());
......
......@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/string/printf.h"
#include "paddle/fluid/string/printf.h"
USE_NO_KERNEL_OP(send);
USE_NO_KERNEL_OP(listen_and_serv);
......
......@@ -29,7 +29,9 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"));
PADDLE_ENFORCE(ctx->HasInput("Y"));
framework::DDim out_dim;
out_dim = ctx->GetInputDim("Y");
auto y_dim = ctx->GetInputDim("Y");
out_dim = ctx->GetInputDim("X");
out_dim[0] = y_dim[0];
ctx->ShareLoD("Y", "Out");
ctx->SetOutputDim("Out", out_dim);
}
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <chrono>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h"
......@@ -27,18 +28,18 @@ class SplitOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto in_stride = framework::stride(in->dims());
auto in_stride = framework::stride_numel(in->dims());
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
const size_t n = outs.size();
auto place = ctx.GetPlace();
size_t input_offset = 0;
for (size_t i = 0; i < n; i++) {
auto& out = outs[i];
for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
size_t axis_dim = out->dims()[axis];
auto out_stride = framework::stride(out->dims());
StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset,
in_stride, out->dims(), out_stride, out->data<T>());
input_offset += axis_dim * in_stride[axis];
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride);
input_offset += out_stride[axis];
}
}
};
......
......@@ -22,7 +22,7 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
SplitSelectedRowsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input SelectedRows.");
AddOutput("Out", "The outputs of input SelectedRows.").AsDuplicable();
AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
AddAttr<std::vector<int>>("height_sections",
"Height for each output SelectedRows.")
.SetDefault(std::vector<int>({}));
......@@ -56,27 +56,6 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("X"), "SplitSelectedRowsOp must has input X.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"SplitSelectedRowsOp must has output Out.");
std::vector<int> height_sections =
ctx->Attrs().Get<std::vector<int>>("height_sections");
int64_t n = ctx->Outputs("Out").size();
std::vector<framework::DDim> outs_dims;
outs_dims.reserve(n);
// make output dims
for (int64_t i = 0; i < n; ++i) {
auto dims = ctx->GetInputDim("X");
if (height_sections.size()) {
PADDLE_ENFORCE_EQ(
height_sections.size(), static_cast<size_t>(n),
"The size of height section should be the same with height"
" section size.");
dims[0] = height_sections[i];
}
outs_dims.push_back(dims);
}
ctx->SetOutputsDim("Out", outs_dims);
}
};
......
......@@ -55,6 +55,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
auto rows_idx = outs_rows_idx[i];
outs[i]->set_height(height_sections[i]);
if (rows_idx.size() > 0) {
auto dims = x->GetCompleteDims();
dims[0] = rows_idx.size();
......
......@@ -41,5 +41,62 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
StridedCopyDimVisitor<T> func(dev_ctx, src, src_stride, dst_stride, dst);
boost::apply_visitor(func, dst_dim);
}
// Strided numel memory copy from src to dst along the specified axis.
//
// A "stride numel" DDim holds, for each axis i, the number of elements spanned
// by axis i and all axes after it. For example, for a tensor with dims
// [4, 20, 100], the strided numel is [8000, 2000, 100].
//
// The copy is performed in dst_stride_numel[0] / dst_stride_numel[axis]
// chunks. Chunk i is read from src + i * src_stride_numel[axis] and written to
// dst + i * dst_stride_numel[axis]. Each chunk carries
// min(src_stride_numel[axis], dst_stride_numel[axis]) elements, so the copy
// never reads past a src chunk (src smaller than dst along the axis) and never
// writes past a dst chunk (src larger than dst along the axis).
//
// NOTE: The src and dst tensor should have the same elements
// except the specified axis. This is enforced for every axis before and after
// `axis`.
template <typename T>
inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
                                     int64_t axis, T* dst,
                                     const framework::DDim& dst_stride_numel,
                                     const T* src,
                                     const framework::DDim& src_stride_numel) {
  // Validate the ranks before indexing into either DDim.
  PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
                    "src and dst tensor should have the same dims size.");

  const int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
  const int64_t src_after = src_stride_numel[axis];
  const int64_t dst_after = dst_stride_numel[axis];
  // Only the part of a chunk present in both tensors can be copied; using
  // src_after unconditionally would overrun dst whenever src is the larger
  // tensor along `axis`.
  const int64_t copy_numel = src_after < dst_after ? src_after : dst_after;

  const int64_t rank = src_stride_numel.size();
  for (int64_t i = 0; i < rank; ++i) {
    if (i < axis) {
      // Leading axes: compare the element counts with the axis factored out.
      PADDLE_ENFORCE_EQ(src_stride_numel[i] / src_stride_numel[axis],
                        dst_stride_numel[i] / dst_stride_numel[axis],
                        "src and dst should have the same elements "
                        "except the specified axis.");
    } else if (i > axis) {
      // Trailing axes: the stride numel itself must match.
      PADDLE_ENFORCE_EQ(src_stride_numel[i], dst_stride_numel[i],
                        "src and dst should have the same elements "
                        "except the specified axis.");
    }
    // i == axis is the only axis allowed to differ, so it is not checked.
  }

  auto place = ctx.GetPlace();
  for (int64_t i = 0; i < before; ++i) {
    if (platform::is_cpu_place(place)) {
      auto& cpu_place = boost::get<platform::CPUPlace>(place);
      memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
                   src + i * src_after, sizeof(T) * copy_numel);
    } else {
#ifdef PADDLE_WITH_CUDA
      auto& gpu_place = boost::get<platform::CUDAPlace>(place);
      auto& cuda_ctx =
          reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
      memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
                   src + i * src_after, sizeof(T) * copy_numel,
                   cuda_ctx.stream());
#else
      PADDLE_THROW("Paddle is not compiled with GPU");
#endif
    }
  }
}
} // namespace operators
} // namespace paddle
......@@ -116,7 +116,9 @@ class SumKernel : public framework::OpKernel<T> {
int64_t offset = 0;
for (int i = 0; i < N; i++) {
auto &sel_row = get_selected_row(i);
if (!sel_row.value().IsInitialized() || sel_row.rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
functor(context.template device_context<DeviceContext>(), sel_row,
offset, out);
......
......@@ -22,69 +22,43 @@ class TargetAssignOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
// checkout inputs
PADDLE_ENFORCE(ctx->HasInput("EncodedGTBBox"),
"Input(EncodedGTBBox) of TargetAssignOp should not be null");
PADDLE_ENFORCE(ctx->HasInput("GTScoreLabel"),
"Input(GTScoreLabel) of TargetAssignOp should not be null");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of TargetAssignOp should not be null");
PADDLE_ENFORCE(ctx->HasInput("MatchIndices"),
"Input(MatchIndices) of TargetAssignOp should not be null");
PADDLE_ENFORCE(ctx->HasInput("NegIndices"),
"Input(NegIndices) of TargetAssignOp should not be null");
// checkout outputs
PADDLE_ENFORCE(
ctx->HasOutput("PredBBoxLabel"),
"Output(PredBBoxLabel) of TargetAssignOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("PredBBoxWeight"),
"Output(PredBBoxWeight) of TargetAssignOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("PredScoreLabel"),
"Output(PredScoreLabel) of TargetAssignOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("PredScoreWeight"),
"Output(PredScoreWeight) of TargetAssignOp should not be null.");
auto blabel_dims = ctx->GetInputDim("EncodedGTBBox");
auto slabel_dims = ctx->GetInputDim("GTScoreLabel");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of TargetAssignOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutWeight"),
"Output(OutWeight) of TargetAssignOp should not be null.");
auto in_dims = ctx->GetInputDim("X");
auto mi_dims = ctx->GetInputDim("MatchIndices");
auto neg_dims = ctx->GetInputDim("NegIndices");
PADDLE_ENFORCE_EQ(blabel_dims.size(), 3UL,
"The rank of Input(EncodedGTBBox) must be 3.");
PADDLE_ENFORCE_EQ(slabel_dims.size(), 2UL,
"The rank of Input(GTScoreLabel) must be 2.");
PADDLE_ENFORCE_EQ(mi_dims.size(), 2UL,
PADDLE_ENFORCE_EQ(in_dims.size(), 3, "The rank of Input(X) must be 3.");
PADDLE_ENFORCE_EQ(mi_dims.size(), 2,
"The rank of Input(MatchIndices) must be 2.");
PADDLE_ENFORCE_EQ(neg_dims.size(), 2UL,
"The rank of Input(NegIndices) must be 2.");
PADDLE_ENFORCE_EQ(blabel_dims[0], slabel_dims[0],
"The 1st dimension (means the total number of "
"ground-truth bounding boxes) of Input(EncodedGTBBox) "
"and Input(GTScoreLabel) must be the same.");
PADDLE_ENFORCE_EQ(blabel_dims[1], mi_dims[1],
"The 2nd dimension (means the number of priod boxes) "
"of Input(EncodedGTBBox) and "
"Input(MatchIndices) must be the same.");
PADDLE_ENFORCE_EQ(blabel_dims[2], 4,
"The 3rd dimension of Input(EncodedGTBBox) must be 4.");
if (ctx->HasInput("NegIndices")) {
auto neg_dims = ctx->GetInputDim("NegIndices");
PADDLE_ENFORCE_EQ(neg_dims.size(), 2,
"The rank of Input(NegIndices) must be 2.");
PADDLE_ENFORCE_EQ(neg_dims[1], 1,
"The last dimenstion of Out(NegIndices) must be 1.");
}
auto n = mi_dims[0];
auto np = mi_dims[1];
ctx->SetOutputDim("PredBBoxLabel", {n, np, 4});
ctx->SetOutputDim("PredBBoxWeight", {n, np, 1});
ctx->SetOutputDim("PredScoreLabel", {n, np, 1});
ctx->SetOutputDim("PredScoreWeight", {n, np, 1});
auto m = mi_dims[1];
auto k = in_dims[in_dims.size() - 1];
ctx->SetOutputDim("Out", {n, m, k});
ctx->SetOutputDim("OutWeight", {n, m, 1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("EncodedGTBBox")->type()),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
......@@ -93,102 +67,87 @@ class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TargetAssignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("EncodedGTBBox",
"(LoDTensor), The encoded ground-truth bounding boxes with shape "
"[Ng, Np, 4], where Ng is the total number of ground-truth boxes "
"in this mini-batch, Np the number of predictions, 4 is the "
"number of coordinate in [xmin, ymin, xmax, ymax] layout.");
AddInput("GTScoreLabel",
"(LoDTensor, default LoDTensor<int>), The input ground-truth "
"labels with shape [Ng, 1], where the Ng is the same as it in "
"the input of EncodedGTBBox.");
AddInput("X",
"(LoDTensor), This input is a 3D LoDTensor with shape [M, P, K]. "
"Some elements in X will be assigned to Out based on the "
"MatchIndices and NegIndices.");
AddInput("MatchIndices",
"(Tensor, default Tensor<int>), The input matched indices "
"with shape [N, Np], where N is the batch size, Np is the same "
"as it in the input of EncodedGTBBox. If MatchIndices[i][j] "
"is -1, the j-th prior box is not matched to any ground-truh "
"box in i-th instance.");
"with shape [N, P], If MatchIndices[i][j] is -1, the j-th entity "
"of column is not matched to any entity of row in i-th instance.");
AddInput("NegIndices",
"(LoDTensor, default LoDTensor<int>), The input negative example "
"indices with shape [Neg, 1], where is the total number of "
"negative example indices.");
AddAttr<int>("background_label",
"(int, default 0), Label index of background class.")
"indices are an optional input with shape [Neg, 1], where Neg is "
"the total number of negative example indices.")
.AsDispensable();
AddAttr<int>("mismatch_value",
"(int, default 0), Fill this value to the "
"mismatched location.")
.SetDefault(0);
AddOutput("PredBBoxLabel",
"(Tensor), The output encoded ground-truth labels "
"with shape [N, Np, 4], N is the batch size and Np, 4 is the "
"same as they in input of EncodedGTBBox. If MatchIndices[i][j] "
"is -1, the PredBBoxLabel[i][j][:] is the encoded ground-truth "
"box for background_label in i-th instance.");
AddOutput("PredBBoxWeight",
"(Tensor), The weight for PredBBoxLabel with the shape "
"of [N, Np, 1]");
AddOutput("PredScoreLabel",
"(Tensor, default Tensor<int>), The output score labels for "
"each predictions with shape [N, Np, 1]. If MatchIndices[i][j] "
"is -1, PredScoreLabel[i][j] = background_label.");
AddOutput("PredScoreWeight",
"(Tensor), The weight for PredScoreLabel with the shape "
"of [N, Np, 1]");
AddOutput("Out",
"(Tensor), The output is a 3D Tensor with shape [N, P, K], "
"N and P is the same as they are in NegIndices, K is the "
"same as it in input of X. If MatchIndices[i][j] "
"is -1, the Out[i][j][0 : K] is the mismatch_value.");
AddOutput("OutWeight",
"(Tensor), The weight for output with the shape of [N, P, 1]");
AddComment(R"DOC(
This operator is, for given the encoded boxes between prior boxes and
ground-truth boxes and ground-truth class labels, to assign classification
and regression targets to each prior box as well as weights to each
prior box. The weights is used to specify which prior box would not contribute
to training loss.
For each instance, the output `PredBBoxLabel`, `PredBBoxWeight`,
`PredScoreLabel` and `PredScoreWeight` are assigned based on `MatchIndices`.
Assumed that the row offset for each instance in `EncodedGTBBox` is called lod,
this operato assigns classification/regression targets by performing the
This operator can be, for given the target bounding boxes or labels,
to assign classification and regression targets to each prediction as well as
weights to prediction. The weights is used to specify which prediction would
not contribute to training loss.
For each instance, the output `Out` and`OutWeight` are assigned based on
`MatchIndices` and `NegIndices`.
Assumed that the row offset for each instance in `X` is called lod,
this operator assigns classification/regression targets by performing the
following steps:
1. Assigning all outpts based on `MatchIndices`:
If id = MatchIndices[i][j] > 0,
PredBBoxLabel[i][j] = EncodedGTBBox[lod[i] + id][j]
PredBBoxWeight[i][j] = 1.
PredScoreLabel[i][j] = GTScoreLabel[lod[i] + id]
PredScoreWeight[i][j] = 1.
Out[i][j][0 : K] = X[lod[i] + id][j % P][0 : K]
OutWeight[i][j] = 1.
Otherwise,
PredBBoxLabel[j][j] = [0., 0., 0., 0.]
PredBBoxWeight[i][j] = 0.
PredScoreLabel[i][j] = background_label
PredScoreWeight[i][j] = 0.
Out[j][j][0 : K] = {mismatch_value, mismatch_value, ...}
OutWeight[i][j] = 0.
2. Assigning PredScoreWeight based on `NegIndices`:
2. Assigning OutWeight based on `NegIndices` if `NegIndices` is provided:
Assumed that the row offset for each instance in `NegIndices` is caleed neg_lod,
for i-th instance and all ids of NegIndices in this instance:
Assumed that the row offset for each instance in `NegIndices` is called neg_lod,
for i-th instance and each `id` of NegIndices in this instance:
PredScoreLabel[i][id] = background_label
PredScoreWeight[i][id] = 1.0
Out[i][id][0 : K] = {mismatch_value, mismatch_value, ...}
OutWeight[i][id] = 1.0
)DOC");
}
};
template <typename T>
struct NegTargetAssignFunctor<platform::CPUDeviceContext, T> {
template <typename T, typename WT>
struct NegTargetAssignFunctor<platform::CPUDeviceContext, T, WT> {
void operator()(const platform::CPUDeviceContext& ctx, const int* neg_indices,
const size_t* lod, const int num, const int num_prior_box,
const int background_label, int* out_label, T* out_label_wt) {
for (int i = 0; i < num; ++i) {
const size_t* lod, const int N, const int M, const int K,
const int mismatch_value, T* out, WT* out_wt) {
for (int i = 0; i < N; ++i) {
for (size_t j = lod[i]; j < lod[i + 1]; ++j) {
int id = neg_indices[j];
out_label[i * num_prior_box + id] = background_label;
out_label_wt[i * num_prior_box + id] = static_cast<T>(1.0);
int off = (i * M + id) * K;
for (int k = 0; k < K; ++k) {
out[off + k] = mismatch_value;
out_wt[off + k] = static_cast<WT>(1.0);
}
}
}
}
};
template struct NegTargetAssignFunctor<platform::CPUDeviceContext, float>;
template struct NegTargetAssignFunctor<platform::CPUDeviceContext, double>;
template struct NegTargetAssignFunctor<platform::CPUDeviceContext, int, float>;
template struct NegTargetAssignFunctor<platform::CPUDeviceContext, float,
float>;
} // namespace operators
} // namespace paddle
......@@ -198,5 +157,5 @@ REGISTER_OP_WITHOUT_GRADIENT(target_assign, ops::TargetAssignOp,
ops::TargetAssignOpMaker);
REGISTER_OP_CPU_KERNEL(
target_assign,
ops::TargetAssignKernel<paddle::platform::CPUDeviceContext, float>,
ops::TargetAssignKernel<paddle::platform::CPUDeviceContext, double>);
ops::TargetAssignKernel<paddle::platform::CPUDeviceContext, int, float>,
ops::TargetAssignKernel<paddle::platform::CPUDeviceContext, float, float>);
......@@ -17,39 +17,41 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename WT>
__global__ void NegTargetAssignKernel(const int* neg_indices, const size_t* lod,
const int num, const int num_prior_box,
const int background_label,
int* out_label, T* out_label_wt) {
const int N, const int M, const int K,
const int mismatch_value, T* out,
WT* out_wt) {
int bidx = blockIdx.x;
int st = lod[bidx];
int ed = lod[bidx + 1];
int row_start = bidx * num_prior_box;
int row_start = bidx * M;
for (int i = st + threadIdx.x; i < ed; i += blockDim.x) {
int id = row_start + neg_indices[i];
out_label[id] = background_label;
out_label_wt[id] = 1.;
for (int k = 0; k < K; ++k) {
out[id * K + k] = T(mismatch_value);
out_wt[id * K + k] = WT(1.);
}
}
}
template <typename T>
struct NegTargetAssignFunctor<platform::CUDADeviceContext, T> {
template <typename T, typename WT>
struct NegTargetAssignFunctor<platform::CUDADeviceContext, T, WT> {
void operator()(const platform::CUDADeviceContext& ctx,
const int* neg_indices, const size_t* lod, const int num,
const int num_prior_box, const int background_label,
int* out_label, T* out_label_wt) {
const int* neg_indices, const size_t* lod, const int N,
const int M, const int K, const int mismatch_value, T* out,
WT* out_wt) {
const int block_size = 256;
const int grid_size = num;
NegTargetAssignKernel<T><<<grid_size, block_size, 0, ctx.stream()>>>(
neg_indices, lod, num, num_prior_box, background_label, out_label,
out_label_wt);
const int grid_size = N;
NegTargetAssignKernel<T, WT><<<grid_size, block_size, 0, ctx.stream()>>>(
neg_indices, lod, N, M, K, mismatch_value, out, out_wt);
}
};
template struct NegTargetAssignFunctor<platform::CUDADeviceContext, float>;
template struct NegTargetAssignFunctor<platform::CUDADeviceContext, double>;
template struct NegTargetAssignFunctor<platform::CUDADeviceContext, int, float>;
template struct NegTargetAssignFunctor<platform::CUDADeviceContext, float,
float>;
} // namespace operators
} // namespace paddle
......@@ -57,5 +59,5 @@ template struct NegTargetAssignFunctor<platform::CUDADeviceContext, double>;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
target_assign,
ops::TargetAssignKernel<paddle::platform::CUDADeviceContext, float>,
ops::TargetAssignKernel<paddle::platform::CUDADeviceContext, double>);
ops::TargetAssignKernel<paddle::platform::CUDADeviceContext, int, float>,
ops::TargetAssignKernel<paddle::platform::CUDADeviceContext, float, float>);
......@@ -19,140 +19,113 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
template <typename T, typename WT>
struct TargetAssignFunctor {
const T* gt_box_;
const int* gt_label_;
const T* in_;
const int* match_indices_;
const size_t* lod_;
const int background_label_;
const int64_t num_;
const int64_t num_prior_box_;
T* out_box_;
T* out_box_wt_;
int* out_label_;
T* out_label_wt_;
TargetAssignFunctor(const T* gt_box, const int* gt_label,
const int* match_indices, const size_t* lod,
const int background_label, const int64_t num,
const int64_t np, T* out_box, T* out_box_wt,
int* out_label, T* out_label_wt)
: gt_box_(gt_box),
gt_label_(gt_label),
const int mismatch_value_;
const int64_t N_;
const int64_t M_;
const int64_t P_;
const int64_t K_;
T* out_;
WT* out_wt_;
TargetAssignFunctor(const T* input, const int* match_indices,
const size_t* lod, const int mismatch_value,
const int64_t N, const int64_t M, const int64_t P,
const int64_t K, T* out, WT* out_wt)
: in_(input),
match_indices_(match_indices),
lod_(lod),
background_label_(background_label),
num_(num),
num_prior_box_(np),
out_box_(out_box),
out_box_wt_(out_box_wt),
out_label_(out_label),
out_label_wt_(out_label_wt) {}
mismatch_value_(mismatch_value),
N_(N),
M_(M),
P_(P),
K_(K),
out_(out),
out_wt_(out_wt) {}
HOSTDEVICE void operator()(size_t i) const {
int row = i / num_prior_box_;
int col = i - row * num_prior_box_;
int h = i / M_;
int w = i - h * M_;
size_t row_off = lod_[row];
int offset = row * num_prior_box_ + col;
size_t off = lod_[h];
int id = match_indices_[i];
int id = match_indices_[offset];
T* obox = out_box_ + offset * 4;
int* olabel = out_label_ + offset;
T* obox_wt = out_box_wt_ + offset;
T* olabel_wt = out_label_wt_ + offset;
T* out = out_ + i * K_;
WT* out_wt = out_wt_ + i;
if (id > -1) {
const T* gtbox = gt_box_ + ((row_off + id) * num_prior_box_ + col) * 4;
obox[0] = gtbox[0];
obox[1] = gtbox[1];
obox[2] = gtbox[2];
obox[3] = gtbox[3];
olabel[0] = gt_label_[row_off + id];
obox_wt[0] = static_cast<T>(1.);
olabel_wt[0] = static_cast<T>(1.);
int w_off = w % P_;
const T* in = in_ + ((off + id) * P_ + w_off) * K_;
for (int64_t k = 0; k < K_; ++k) {
out[k] = in[k];
}
out_wt[0] = static_cast<WT>(1.);
} else {
obox[0] = static_cast<T>(0.);
obox[1] = static_cast<T>(0.);
obox[2] = static_cast<T>(0.);
obox[3] = static_cast<T>(0.);
olabel[0] = background_label_;
obox_wt[0] = static_cast<T>(0.);
olabel_wt[0] = static_cast<T>(0.);
for (int64_t k = 0; k < K_; ++k) {
out[k] = static_cast<T>(mismatch_value_);
}
out_wt[0] = static_cast<WT>(0.);
}
}
};
template <typename DeviceContext, typename T>
template <typename DeviceContext, typename T, typename WT>
struct NegTargetAssignFunctor {
void operator()(const platform::DeviceContext& ctx, const int* neg_indices,
const size_t* lod, const int num, const int num_prior_box,
const int background_label, int* out_label,
T* out_label_wt) const;
const size_t* lod, const int N, const int M, const int K,
const int mismatch_value, T* out, WT* out_wt) const;
};
template <typename DeviceContext, typename T>
template <typename DeviceContext, typename T, typename WT>
class TargetAssignKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* enc_gt_box = ctx.Input<framework::LoDTensor>("EncodedGTBBox");
auto* gt_label = ctx.Input<framework::LoDTensor>("GTScoreLabel");
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* match_indices = ctx.Input<framework::Tensor>("MatchIndices");
auto* neg_indices = ctx.Input<framework::LoDTensor>("NegIndices");
auto* out_box = ctx.Output<framework::Tensor>("PredBBoxLabel");
auto* out_box_wt = ctx.Output<framework::Tensor>("PredBBoxWeight");
auto* out_label = ctx.Output<framework::Tensor>("PredScoreLabel");
auto* out_label_wt = ctx.Output<framework::Tensor>("PredScoreWeight");
PADDLE_ENFORCE_EQ(enc_gt_box->lod().size(), 1UL);
PADDLE_ENFORCE_EQ(gt_label->lod().size(), 1UL);
PADDLE_ENFORCE_EQ(neg_indices->lod().size(), 1UL);
auto* out = ctx.Output<framework::Tensor>("Out");
auto* out_wt = ctx.Output<framework::Tensor>("OutWeight");
int background_label = ctx.Attr<int>("background_label");
PADDLE_ENFORCE_EQ(x->lod().size(), 1UL);
int mismatch_value = ctx.Attr<int>("mismatch_value");
const T* box_data = enc_gt_box->data<T>();
const int* label_data = gt_label->data<int>();
const T* x_data = x->data<T>();
const int* match_idx_data = match_indices->data<int>();
const int* neg_idx_data = neg_indices->data<int>();
T* obox_data = out_box->mutable_data<T>(ctx.GetPlace());
T* obox_wt_data = out_box_wt->mutable_data<T>(ctx.GetPlace());
int* olabel_data = out_label->mutable_data<int>(ctx.GetPlace());
T* olabel_wt_data = out_label_wt->mutable_data<T>(ctx.GetPlace());
T* out_data = out->mutable_data<T>(ctx.GetPlace());
WT* out_wt_data = out_wt->mutable_data<WT>(ctx.GetPlace());
int64_t num = match_indices->dims()[0];
int64_t num_prior_box = match_indices->dims()[1];
int64_t n = match_indices->dims()[0];
int64_t m = match_indices->dims()[1];
int64_t p = x->dims()[1];
int64_t k = x->dims()[2];
auto gt_lod = enc_gt_box->lod().back();
auto gt_label_lod = gt_label->lod().back();
auto neg_lod = neg_indices->lod().back();
for (size_t i = 0; i < gt_lod.size(); ++i) {
PADDLE_ENFORCE_EQ(gt_lod.data()[i], gt_label_lod.data()[i]);
}
size_t* gt_lod_data = gt_lod.MutableData(ctx.GetPlace());
size_t* neg_lod_data = neg_lod.MutableData(ctx.GetPlace());
auto x_lod = x->lod().back();
size_t* x_lod_data = x_lod.MutableData(ctx.GetPlace());
TargetAssignFunctor<T> functor(box_data, label_data, match_idx_data,
gt_lod_data, background_label, num,
num_prior_box, obox_data, obox_wt_data,
olabel_data, olabel_wt_data);
TargetAssignFunctor<T, WT> functor(x_data, match_idx_data, x_lod_data,
mismatch_value, n, m, p, k, out_data,
out_wt_data);
auto& device_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(device_ctx,
num * num_prior_box);
platform::ForRange<DeviceContext> for_range(device_ctx, n * m);
for_range(functor);
NegTargetAssignFunctor<DeviceContext, T> neg_trg_functor;
neg_trg_functor(device_ctx, neg_idx_data, neg_lod_data, num, num_prior_box,
background_label, olabel_data, olabel_wt_data);
auto* neg_indices = ctx.Input<framework::LoDTensor>("NegIndices");
if (neg_indices) {
PADDLE_ENFORCE_EQ(neg_indices->lod().size(), 1UL);
const int* neg_idx_data = neg_indices->data<int>();
auto neg_lod = neg_indices->lod().back();
size_t* neg_lod_data = neg_lod.MutableData(ctx.GetPlace());
NegTargetAssignFunctor<DeviceContext, T, WT> neg_trg_functor;
neg_trg_functor(device_ctx, neg_idx_data, neg_lod_data, n, m, k,
mismatch_value, out_data, out_wt_data);
}
}
};
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/string/printf.h"
#include "paddle/fluid/string/printf.h"
#include <ostream>
#include <sstream>
......
......@@ -23,8 +23,8 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/platform/macros.h"
#include "paddle/string/printf.h"
#include "paddle/string/to_string.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/to_string.h"
#ifdef __GNUC__
#include <cxxabi.h> // for __cxa_demangle
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/string/piece.h"
#include "paddle/fluid/string/piece.h"
using StringPiece = paddle::string::Piece;
using paddle::string::HasPrefix;
......
......@@ -35,7 +35,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/string/to_string.h"
#include "paddle/fluid/string/to_string.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/string/piece.h"
#include "piece.h"
#include <string.h>
......
......@@ -28,7 +28,7 @@ namespace string {
// its syntax is simple as it doesn't own/manage the string, it is
// cheap to construct Pieces and pass them around.
class Piece {
public:
public:
static const size_t npos = static_cast<size_t>(-1);
// We provide non-explicit singleton constructors so users can
......@@ -55,7 +55,7 @@ public:
// Return a string that contains the copy of the referenced data.
std::string ToString() const { return std::string(data_, size_); }
private:
private:
const char* data_;
size_t size_;
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/string/piece.h"
#include "paddle/fluid/string/piece.h"
#include <sstream>
......
......@@ -71,7 +71,7 @@
#include <iostream>
#include <sstream>
#include "paddle/string/tinyformat/tinyformat.h" // https://github.com/c42f/tinyformat
#include "tinyformat/tinyformat.h" // https://github.com/c42f/tinyformat
namespace paddle {
namespace string {
......
......@@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/string/printf.h"
#include "printf.h"
#include <string>
......@@ -24,6 +24,6 @@ TEST(StringPrintf, StringPrintf) {
long hour = 14;
int min = 44;
EXPECT_EQ(std::string("Wednesday, July 27, 14:44"),
paddle::string::Sprintf(
"%s, %s %d, %.2d:%.2d", weekday, month, day, hour, min));
paddle::string::Sprintf("%s, %s %d, %.2d:%.2d", weekday, month, day,
hour, min));
}
......@@ -147,7 +147,7 @@ namespace detail {
// Test whether type T1 is convertible to type T2
template <typename T1, typename T2>
struct is_convertible {
private:
private:
// two types of different size
struct fail {
char dummy[2];
......@@ -160,7 +160,7 @@ private:
static succeed tryConvert(const T2 &);
static const T1 &makeT1();
public:
public:
// Standard trick: the (...) version of tryConvert will be chosen from
// the overload set only if the version taking a T2 doesn't match.
// Then we compare the sizes of the return types to check which
......@@ -170,8 +170,7 @@ public:
// Format the value by casting to type fmtT. This default implementation
// should never be called.
template <typename T,
typename fmtT,
template <typename T, typename fmtT,
bool convertible = is_convertible<T, fmtT>::value>
struct formatValueAsType {
static void invoke(std::ostream & /*out*/, const T & /*value*/) { assert(0); }
......@@ -241,11 +240,8 @@ TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(char)
/// operator<< to format the type T, with special cases for the %c and %p
/// conversions.
template <typename T>
inline void formatValue(std::ostream &out,
const char * /*fmtBegin*/,
const char *fmtEnd,
int ntrunc,
const T &value) {
inline void formatValue(std::ostream &out, const char * /*fmtBegin*/,
const char *fmtEnd, int ntrunc, const T &value) {
// The mess here is to support the %c and %p conversions: if these
// conversions are active we try to convert the type to a char or const
// void* respectively and format that instead of the value itself. For the
......@@ -267,25 +263,22 @@ inline void formatValue(std::ostream &out,
}
// Overloaded version for char types to support printing as an integer
#define TINYFORMAT_DEFINE_FORMATVALUE_CHAR(charType) \
inline void formatValue(std::ostream &out, \
const char * /*fmtBegin*/, \
const char *fmtEnd, \
int /**/, \
charType value) { \
switch (*(fmtEnd - 1)) { \
case 'u': \
case 'd': \
case 'i': \
case 'o': \
case 'X': \
case 'x': \
out << static_cast<int>(value); \
break; \
default: \
out << value; \
break; \
} \
#define TINYFORMAT_DEFINE_FORMATVALUE_CHAR(charType) \
inline void formatValue(std::ostream &out, const char * /*fmtBegin*/, \
const char *fmtEnd, int /**/, charType value) { \
switch (*(fmtEnd - 1)) { \
case 'u': \
case 'd': \
case 'i': \
case 'o': \
case 'X': \
case 'x': \
out << static_cast<int>(value); \
break; \
default: \
out << value; \
break; \
} \
}
// per 3.9.1: char, signed char and unsigned char are all distinct types
TINYFORMAT_DEFINE_FORMATVALUE_CHAR(char)
......@@ -482,7 +475,7 @@ namespace detail {
// each argument to be allocated as a homogenous array inside FormatList
// whereas a naive implementation based on inheritance does not.
class FormatArg {
public:
public:
FormatArg() {}
template <typename T>
......@@ -491,22 +484,17 @@ public:
m_formatImpl(&formatImpl<T>),
m_toIntImpl(&toIntImpl<T>) {}
void format(std::ostream &out,
const char *fmtBegin,
const char *fmtEnd,
void format(std::ostream &out, const char *fmtBegin, const char *fmtEnd,
int ntrunc) const {
m_formatImpl(out, fmtBegin, fmtEnd, ntrunc, m_value);
}
int toInt() const { return m_toIntImpl(m_value); }
private:
private:
template <typename T>
static void formatImpl(std::ostream &out,
const char *fmtBegin,
const char *fmtEnd,
int ntrunc,
const void *value) {
static void formatImpl(std::ostream &out, const char *fmtBegin,
const char *fmtEnd, int ntrunc, const void *value) {
formatValue(out, fmtBegin, fmtEnd, ntrunc, *static_cast<const T *>(value));
}
......@@ -516,11 +504,8 @@ private:
}
const void *m_value;
void (*m_formatImpl)(std::ostream &out,
const char *fmtBegin,
const char *fmtEnd,
int ntrunc,
const void *value);
void (*m_formatImpl)(std::ostream &out, const char *fmtBegin,
const char *fmtEnd, int ntrunc, const void *value);
int (*m_toIntImpl)(const void *value);
};
......@@ -569,12 +554,10 @@ inline const char *printFormatStringLiteral(std::ostream &out,
// necessary to pull out variable width and precision . The function returns a
// pointer to the character after the end of the current format spec.
inline const char *streamStateFromFormat(std::ostream &out,
bool &spacePadPositive,
int &ntrunc,
bool &spacePadPositive, int &ntrunc,
const char *fmtStart,
const detail::FormatArg *formatters,
int &argIndex,
int numFormatters) {
int &argIndex, int numFormatters) {
if (*fmtStart != '%') {
TINYFORMAT_ERROR(
"tinyformat: Not enough conversion specifiers in format string");
......@@ -750,10 +733,8 @@ inline const char *streamStateFromFormat(std::ostream &out,
}
//------------------------------------------------------------------------------
inline void formatImpl(std::ostream &out,
const char *fmt,
const detail::FormatArg *formatters,
int numFormatters) {
inline void formatImpl(std::ostream &out, const char *fmt,
const detail::FormatArg *formatters, int numFormatters) {
// Saved stream state
std::streamsize origWidth = out.width();
std::streamsize origPrecision = out.precision();
......@@ -765,13 +746,9 @@ inline void formatImpl(std::ostream &out,
fmt = printFormatStringLiteral(out, fmt);
bool spacePadPositive = false;
int ntrunc = -1;
const char *fmtEnd = streamStateFromFormat(out,
spacePadPositive,
ntrunc,
fmt,
formatters,
argIndex,
numFormatters);
const char *fmtEnd =
streamStateFromFormat(out, spacePadPositive, ntrunc, fmt, formatters,
argIndex, numFormatters);
if (argIndex >= numFormatters) {
// Check args remain after reading any variable width/precision
TINYFORMAT_ERROR("tinyformat: Not enough format arguments");
......@@ -820,15 +797,14 @@ inline void formatImpl(std::ostream &out,
/// information has been stripped from the arguments, leaving just enough of a
/// common interface to perform formatting as required.
class FormatList {
public:
public:
FormatList(detail::FormatArg *formatters, int N)
: m_formatters(formatters), m_N(N) {}
friend void vformat(std::ostream &out,
const char *fmt,
friend void vformat(std::ostream &out, const char *fmt,
const FormatList &list);
private:
private:
const detail::FormatArg *m_formatters;
int m_N;
};
......@@ -841,7 +817,7 @@ namespace detail {
// Format list subclass with fixed storage to avoid dynamic allocation
template <int N>
class FormatListN : public FormatList {
public:
public:
template <typename... Args>
FormatListN(const Args &... args)
: FormatList(&m_formatterStore[0], N),
......@@ -849,14 +825,14 @@ public:
static_assert(sizeof...(args) == N, "Number of args must be N");
}
private:
private:
FormatArg m_formatterStore[N];
};
// Special 0-arg version - MSVC says zero-sized C array in struct is nonstandard
template <>
class FormatListN<0> : public FormatList {
public:
public:
FormatListN() : FormatList(0, 0) {}
};
......
......@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/string/to_string.h"
#include "to_string.h"
#include <gtest/gtest.h>
constexpr char kOutputString[] = "User Defined Output";
class UserDefinedClass {
public:
public:
};
std::ostream& operator<<(std::ostream& s, const UserDefinedClass& ins) {
......
......@@ -115,8 +115,8 @@ EOF
-DWITH_AVX=${WITH_AVX:-ON} \
-DWITH_SWIG_PY=ON \
-DWITH_STYLE_CHECK=OFF
make -j `nproc` gen_proto_py
make -j `nproc` paddle_python
make -j `nproc` gen_proto_py framework_py_proto
make -j `nproc` copy_paddle_pybind
make -j `nproc` paddle_docs paddle_docs_cn paddle_api_docs
popd
fi
......
......@@ -6,9 +6,9 @@ mkdir -p $TRAVIS_BUILD_DIR/build
cd $TRAVIS_BUILD_DIR/build
# Compile Documentation only.
cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_MKL=OFF -DWITH_DOC=ON
make -j `nproc` gen_proto_py
make -j `nproc` paddle_python
cmake .. -DCMAKE_BUILD_TYPE=Release -DWITH_GPU=OFF -DWITH_MKL=OFF -DWITH_DOC=ON -DWITH_STYLE_CHECK=OFF
make -j `nproc` gen_proto_py framework_py_proto
make -j `nproc` copy_paddle_pybind
make -j `nproc` paddle_docs paddle_docs_cn paddle_api_docs
# check websites for broken links
......
......@@ -33,6 +33,57 @@ class VarBlock:
return "%s:%d:%d" % (self.varname, self.offset, self.size)
class UnionFind(object):
    """ Union-find (disjoint-set) data structure.

    Keeps track of a set of elements partitioned into a number of disjoint
    (non-overlapping) subsets, and supports merging two subsets and testing
    whether two elements belong to the same subset.

    Reference:
    https://en.wikipedia.org/wiki/Disjoint-set_data_structure

    Args:
        elementes(list): The initial elements. Each element must be hashable
            and starts out in a subset of its own. None or an empty list
            creates an empty structure.
    """

    def __init__(self, elementes=None):
        self._parents = []  # index -> parent index
        self._index = {}  # element -> index
        self._curr_idx = 0
        if not elementes:
            elementes = []
        for ele in elementes:
            self._parents.append(self._curr_idx)
            self._index[ele] = self._curr_idx
            self._curr_idx += 1

    def find(self, x):
        """Return the root index of the subset containing x.

        Returns -1 when x was not part of the initial element list. While
        walking towards the root, every visited node is re-pointed at its
        grandparent so that later lookups follow a shorter chain.
        """
        if x not in self._index:
            return -1

        idx = self._index[x]
        while idx != self._parents[idx]:
            t = self._parents[idx]
            self._parents[idx] = self._parents[t]
            idx = t
        return idx

    def union(self, x, y):
        """Merge the subsets containing x and y.

        If either element is unknown (find() returns -1) nothing is changed.
        Without this guard the -1 sentinel would be used as a list index and
        silently rewrite the parent of the last registered element.
        """
        x_root = self.find(x)
        y_root = self.find(y)

        if x_root == -1 or y_root == -1:
            return
        if x_root == y_root:
            return
        self._parents[x_root] = y_root

    def is_connected(self, x, y):
        """Return True if x and y are known elements of the same subset.

        Unknown elements are never connected to anything; comparing the raw
        find() results would report two unknown elements as connected because
        both yield the -1 sentinel.
        """
        x_root = self.find(x)
        if x_root == -1:
            return False
        return x_root == self.find(y)
def same_or_split_var(p_name, var_name):
    """Tell whether p_name names var_name itself or one of its split blocks.

    A split block is any name that begins with var_name followed by ".block".
    """
    if p_name == var_name:
        return True
    block_prefix = var_name + ".block"
    return p_name.startswith(block_prefix)
......@@ -140,6 +191,7 @@ class DistributeTranspiler:
for b in param_blocks:
varname, block_id, _ = b.split(":")
send_outputs.append(param_var_mapping[varname][int(block_id)])
# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
eplist = split_method(send_inputs, pserver_endpoints)
......@@ -178,6 +230,21 @@ class DistributeTranspiler:
outputs={"Out": [orig_param]},
attrs={"axis": 0})
self.lr_param_mapping = self._create_lr_param_mapping()
def _create_lr_param_mapping(self):
lr_mapping = dict()
for _, opt_op in enumerate(self.optimize_ops):
if not opt_op.inputs or not opt_op.inputs.has_key("LearningRate") \
or not opt_op.inputs.has_key("Param"):
continue
lr = opt_op.inputs["LearningRate"].name
param = opt_op.inputs["Param"].name
if not lr_mapping.has_key(lr):
lr_mapping.update({lr: list()})
lr_mapping[lr].append(param)
return lr_mapping
def _create_vars_from_blocklist(self, program, block_list):
# Create respective variables using the block_list
block_map = dict()
......@@ -208,6 +275,7 @@ class DistributeTranspiler:
name="%s.block%d" % (varname, i),
psersistable=False,
dtype=orig_var.dtype,
type=orig_var.type,
shape=splited_shape) # flattend splited var
var_mapping[varname].append(var)
return var_mapping
......@@ -269,6 +337,7 @@ class DistributeTranspiler:
name="%s.trainer_%d" % (var.name, i),
psersistable=var.persistable,
dtype=var.dtype,
type=var.type,
shape=var.shape)
var_list.append(var_each)
return var_list
......@@ -300,52 +369,15 @@ class DistributeTranspiler:
pass
return orig_shape
def _op_input_var(self, op, varname):
pass
def _is_op_on_pserver(self, endpoint, all_ops, idx):
"""
Recursively check if the op need to run on current server.
Assume that ops are in the execution order.

Args:
endpoint: key into self.param_grad_ep_mapping; its "params" entry
lists the parameters compared against the op's "Param" input.
all_ops: sequence of ops; only all_ops[idx] and the ops before it
are examined.
idx: position in all_ops of the op being checked.

Returns:
True if the op is judged to belong to this endpoint, else False.
"""
# Names of the parameters registered for this endpoint.
param_names = [
p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
]
op = all_ops[idx]
input_names = set(op.input_names)
# TODO(typhoonzero): using Param and Grad input name to identify
# that the operator is an optimization operator, need a better way.
# An op with a "Param" input is decided directly from that parameter's
# name: an exact match, or a match through same_or_split_var().
if "Param" in input_names:
if op.input("Param")[0] in param_names:
return True
else:
for n in param_names:
if same_or_split_var(n, op.input("Param")[0]) \
and n != op.input("Param")[0]:
return True
return False
else:
# No "Param" input: scan the preceding ops from idx - 1 down to 0 and
# look for one that shares a variable with this op.
j = idx - 1
while j >= 0:
prev_op = all_ops[j]
# prev_output_names = [o.name for o in prev_op.outputs.values()]
# prev_input_names = [o.name for o in prev_op.inputs.values()]
# NOTE(typhoonzero): consider list input/output
prev_output_names = prev_op.desc.output_arg_names()
prev_input_names = prev_op.desc.input_arg_names()
found1 = False
found2 = False
# found1: this op reads a variable that prev_op writes; the answer is
# whatever the recursive check says about prev_op.
for varname in op.desc.input_arg_names():
if varname in prev_output_names:
found1 = self._is_op_on_pserver(endpoint, all_ops, j)
# later ops may produce output for prev op's next batch use.
# found2: this op writes a variable that prev_op reads.
for varname in op.desc.output_arg_names():
if varname in prev_input_names:
found2 = self._is_op_on_pserver(endpoint, all_ops, j)
if found1 or found2:
return True
j -= 1
return False
def _fetch_var_names(self, param_dict):
res = []
if not param_dict:
return res
for _, values in param_dict.iteritems():
if not isinstance(values, list):
values = [values]
res += [v.name for v in values]
return res
def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
program = optimize_block.program
......@@ -363,11 +395,7 @@ class DistributeTranspiler:
# do not append this op if current endpoint
# is not dealing with this grad block
return
merged_var = program.global_block().create_var(
name=grad_block.name,
persistable=grad_block.persistable,
dtype=grad_block.dtype,
shape=grad_block.shape)
merged_var = program.global_block().vars[grad_block.name]
# append merging ops if trainers > 1
if self.trainers > 1:
vars2merge = self._create_var_for_trainers(
......@@ -398,13 +426,19 @@ class DistributeTranspiler:
shape=param_block.shape)
new_inputs[key] = tmpvar
elif key == "LearningRate":
# the learning rate variable has already been created by a non-optimize op,
# so don't create it again.
new_inputs[key] = program.global_block().vars[opt_op.input(key)[
0]]
for key in opt_op.input_names:
if key in ["Param", "Grad"]:
new_shape = None
if key in ["Param", "Grad", "LearningRate"]:
continue
var = program.global_block().vars[opt_op.input(key)[0]]
# update accumulator variable shape
param_shape = new_inputs["Param"].shape
var = program.global_block().vars[opt_op.input(key)[0]]
new_shape = self._get_optimizer_input_shape(opt_op.type, key,
var.shape, param_shape)
tmpvar = program.global_block().create_var(
......@@ -415,12 +449,11 @@ class DistributeTranspiler:
new_inputs[key] = tmpvar
# change output's ParamOut variable
outputs = self._get_output_map_from_op(program.global_block(), opt_op)
outputs["ParamOut"] = new_inputs["Param"]
opt_op.outputs["ParamOut"] = new_inputs["Param"]
optimize_block.append_op(
type=opt_op.type,
inputs=new_inputs,
outputs=outputs,
outputs=opt_op.outputs,
attrs=opt_op.attrs)
def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
......@@ -428,11 +461,10 @@ class DistributeTranspiler:
# Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op(self.program.global_block().vars,
opt_op)
for var in inputs.itervalues():
if type(var) == list:
varlist = var
else:
varlist = [var]
for varlist in inputs.itervalues():
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
if not program.global_block().vars.has_key(var.name):
program.global_block().create_var(
......@@ -444,12 +476,70 @@ class DistributeTranspiler:
outputs = self._get_output_map_from_op(self.program.global_block().vars,
opt_op)
for varlist in outputs.itervalues():
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
optimize_block.append_op(
type=opt_op.type,
inputs=inputs,
outputs=outputs,
attrs=opt_op.attrs)
def _is_op_connected(self, op1, op2):
# If one op's input is another op's output or
# one op's output is another op's input, we say
# the two operator is connected.
op1_input_names = self._fetch_var_names(op1.inputs)
op1_output_names = self._fetch_var_names(op1.outputs)
op2_input_names = self._fetch_var_names(op2.inputs)
op2_output_names = self._fetch_var_names(op2.outputs)
if set(op1_output_names) & set(op2_input_names) or \
set(op1_input_names) & set(op2_output_names):
return True
return False
def _create_ufind(self, optimize_ops):
    """Build a union-find structure that groups connected optimize ops.

    Every op starts in its own set; each pair of ops for which
    self._is_op_connected() is true gets merged into one set.

    Args:
        optimize_ops(list): the operators to partition.

    Returns:
        UnionFind: the structure after all connected pairs were merged.
    """
    ufind = UnionFind(optimize_ops)
    op_count = len(optimize_ops)
    for i in xrange(op_count):
        op1 = optimize_ops[i]
        # Start at i + 1: pairing an op with itself can only produce a
        # no-op union, so that connectivity check is wasted work.
        for j in xrange(i + 1, op_count):
            op2 = optimize_ops[j]
            if self._is_op_connected(op1, op2):
                ufind.union(op1, op2)
    return ufind
def _is_opt_op(self, op):
# NOTE: It's a HACK implement.
# optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
if op.inputs and op.inputs.has_key("Param") \
and op.inputs.has_key("LearningRate"):
return True
return False
def _is_opt_op_on_pserver(self, endpoint, op):
    """Check whether an optimize op updates a parameter held by endpoint.

    Args:
        endpoint: key into self.param_grad_ep_mapping; its "params" entry
            lists the parameter variables assigned to that endpoint.
        op: an op whose inputs contain a "Param" variable.

    Returns:
        bool: True if the op's "Param" name equals one of the endpoint's
        parameter names, or if one of those names matches it through
        same_or_split_var(); False otherwise.
    """
    param_names = [
        p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
    ]
    param = op.inputs["Param"].name
    if param in param_names:
        return True
    # From here on param is not in param_names, so every candidate differs
    # from it and no extra inequality check is needed.
    for n in param_names:
        if same_or_split_var(n, param):
            return True
    return False
def get_pserver_program(self, endpoint):
"""
Get pserver side program using the endpoint
......@@ -469,26 +559,38 @@ class DistributeTranspiler:
pserver_program.global_block().create_var(
name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
for trainer_id in xrange(self.trainers):
print("create variable for program: %s.trainer_%d" %
(v.name, trainer_id))
pserver_program.global_block().create_var(
name="%s.trainer_%d" % (v.name, trainer_id),
persistable=True,
dtype=v.dtype,
shape=v.shape)
# step6
optimize_block = pserver_program.create_block(0)
# Iterate through the ops and append ops as needed
for idx, opt_op in enumerate(self.optimize_ops):
is_op_on_pserver = self._is_op_on_pserver(endpoint,
self.optimize_ops, idx)
if not is_op_on_pserver:
continue
if "Grad" in opt_op.desc.input_arg_names():
self._append_pserver_ops(optimize_block, opt_op, endpoint)
else:
self._append_pserver_non_opt_ops(optimize_block, opt_op)
# step 6.1
# Create a union-find data struct by optimize ops,
# If two ops are connected, we could add these two ops
# into one set.
ufind = self._create_ufind(self.optimize_ops)
# step 6.2
# Iterate through the ops and append optimize op which
# located on current pserver
opt_op_on_pserver = []
for _, op in enumerate(self.optimize_ops):
if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
opt_op_on_pserver.append(op)
# step 6.3
# Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then
# append it into the sub program.
for _, op in enumerate(self.optimize_ops):
for _, opt_op in enumerate(opt_op_on_pserver):
if ufind.is_connected(op, opt_op):
if self._is_opt_op(op):
self._append_pserver_ops(optimize_block, op, endpoint)
else:
self._append_pserver_non_opt_ops(optimize_block, op)
break
# Append the listen_and_serv op
pserver_program.global_block().append_op(
type="listen_and_serv",
......
......@@ -16,6 +16,8 @@ import ops
from ops import *
import nn
from nn import *
import detection
from detection import *
import io
from io import *
import tensor
......@@ -31,6 +33,7 @@ from detection import *
__all__ = []
__all__ += math_op_patch.__all__
__all__ += detection.__all__
__all__ += nn.__all__
__all__ += io.__all__
__all__ += tensor.__all__
......
......@@ -22,7 +22,106 @@ from ops import reshape
from operator import mul
import math
__all__ = ['prior_box', ]
__all__ = [
'detection_output',
'prior_box',
]
def detection_output(scores,
                     loc,
                     prior_box,
                     prior_box_var,
                     background_label=0,
                     nms_threshold=0.3,
                     nms_top_k=400,
                     keep_top_k=200,
                     score_threshold=0.01,
                     nms_eta=1.0):
    """
    **Detection Output Layer**

    This layer applies the NMS to the output of network and computes the
    predict bounding box location. The output's shape of this layer could
    be zero if there is no valid bounding box.

    Internally it appends two ops: a `box_coder` op (with
    `code_type='decode_center_size'`) that decodes `loc` against the prior
    boxes, followed by a `multiclass_nms` op on the decoded boxes.

    Args:
        scores(Variable): A 3-D Tensor with shape [N, C, M] represents the
            predicted confidence predictions. N is the batch size, C is the
            class number, M is number of bounding boxes. For each category
            there are total M scores which corresponding M bounding boxes.
        loc(Variable): A 3-D Tensor with shape [N, M, 4] represents the
            predicted locations of M bounding bboxes. N is the batch size,
            and each bounding box has four coordinate values and the layout
            is [xmin, ymin, xmax, ymax].
        prior_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes,
            each box is represented as [xmin, ymin, xmax, ymax],
            [xmin, ymin] is the left top coordinate of the anchor box,
            if the input is image feature map, they are close to the origin
            of the coordinate system. [xmax, ymax] is the right bottom
            coordinate of the anchor box.
        prior_box_var(Variable): A 2-D Tensor with shape [M, 4] holds M group
            of variance.
        background_label(int): The index of background label,
            the background label will be ignored. If set to -1, then all
            categories will be considered.
        nms_threshold(float): The threshold to be used in NMS.
        nms_top_k(int): Maximum number of detections to be kept according
            to the confidences after filtering detections based on
            score_threshold.
        keep_top_k(int): Number of total bboxes to be kept per image after
            NMS step. -1 means keeping all bboxes after NMS step.
        score_threshold(float): Threshold to filter out bounding boxes with
            low confidence score. If not provided, consider all boxes.
        nms_eta(float): The parameter for adaptive NMS.

    Returns:
        The detected bounding boxes which are a Tensor.

    Examples:
        .. code-block:: python

            pb = layers.data(name='prior_box', shape=[10, 4],
                         append_batch_size=False, dtype='float32')
            pbv = layers.data(name='prior_box_var', shape=[10, 4],
                          append_batch_size=False, dtype='float32')
            loc = layers.data(name='target_box', shape=[21, 4],
                          append_batch_size=False, dtype='float32')
            scores = layers.data(name='scores', shape=[2, 21, 10],
                          append_batch_size=False, dtype='float32')
            nmsed_outs = fluid.layers.detection_output(scores=scores,
                                       loc=loc,
                                       prior_box=pb,
                                       prior_box_var=pbv)
    """
    # NOTE: locals() must be captured before any other local is created so
    # that the helper only receives the layer's own arguments.
    helper = LayerHelper("detection_output", **locals())

    # Step 1: decode the predicted offsets against the prior boxes.
    decoded_box = helper.create_tmp_variable(dtype=loc.dtype)
    helper.append_op(
        type="box_coder",
        inputs={
            'PriorBox': prior_box,
            'PriorBoxVar': prior_box_var,
            'TargetBox': loc
        },
        outputs={'OutputBox': decoded_box},
        attrs={'code_type': 'decode_center_size'})

    # Step 2: run multi-class NMS on the decoded boxes. The caller-supplied
    # background_label and nms_eta are forwarded (they used to be hard-coded
    # to their defaults, silently ignoring the arguments).
    nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
    helper.append_op(
        type="multiclass_nms",
        inputs={'Scores': scores,
                'BBoxes': decoded_box},
        outputs={'Out': nmsed_outs},
        attrs={
            'background_label': background_label,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': nms_eta
        })
    return nmsed_outs
def prior_box(inputs,
......@@ -47,7 +146,7 @@ def prior_box(inputs,
Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
For details of this algorithm, please refer to section 2.2 of the SSD paper
`SSD: Single Shot MultiBox Detector <https://arxiv.org/abs/1512.02325>`_ .
Args:
inputs(list): The list of input Variables, the format of all Variables is NCHW.
image(Variable): The input image data of PriorBoxOp, the layout is NCHW.
......@@ -73,7 +172,7 @@ def prior_box(inputs,
max_sizes(list, optional, default=None): If `len(inputs) <=2`, max_sizes must
be set, and the length of min_sizes should be equal to the length of inputs.
name(str, optional, None): Name of the prior box layer.
Returns:
boxes(Variable): the output prior boxes of PriorBoxOp. The layout is
[num_priors, 4]. num_priors is the total box count of each
......@@ -81,20 +180,19 @@ def prior_box(inputs,
Variances(Variable): the expanded variances of PriorBoxOp. The layout
is [num_priors, 4]. num_priors is the total box count of each
position of inputs
Examples:
.. code-block:: python
prior_boxes(
prior_box(
inputs = [conv1, conv2, conv3, conv4, conv5, conv6],
image = data,
min_ratio = 20, # 0.20
max_ratio = 90, # 0.90
steps = [8., 16., 32., 64., 100., 300.],
aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
base_size = 300,
offset = 0.5,
base_size = 300,
variance = [0.1,0.1,0.1,0.1],
aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
flip=True,
clip=True)
"""
......
......@@ -117,6 +117,7 @@ def monkey_patch_variable():
tmp_name = unique_tmp_name()
out = self.block.create_var(name=tmp_name, dtype=lhs_dtype)
self.block.append_op(
type=op_type,
inputs={'X': [self],
......
......@@ -99,7 +99,7 @@ elif training_role == "TRAINER":
exe.run(fluid.default_startup_program())
for pass_id in range(PASS_NUM):
for data in train_reader():
avg_cost_np = exe.run(fluid.default_main_program(),
avg_cost_np = exe.run(t.get_trainer_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost])
print("avg_cost_np", avg_cost_np)
......
......@@ -64,9 +64,7 @@ exe = fluid.Executor(place)
[res1, res2] = exe.run(prog, fetch_list=[out1, out2])
test_pass = res1.shape == (10, 2) and res2.shape == (10, 1)
if not test_pass:
if not (res1.shape == (10, 2) and res2.shape == (10, 1)):
exit(1)
exit(0)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.layers.detection as detection
from paddle.v2.fluid.framework import Program, program_guard
import unittest
import numpy as np
class TestBook(unittest.TestCase):
    def test_detection_output(self):
        # Smoke test: build the detection_output layer inside a fresh
        # Program, check that it yields an output, then dump the program.
        def float_input(name, shape):
            # Every input here is a float32 data layer without an implicit
            # batch dimension.
            return layers.data(
                name=name,
                shape=shape,
                append_batch_size=False,
                dtype='float32')

        net = Program()
        with program_guard(net):
            priors = float_input('prior_box', [10, 4])
            prior_vars = float_input('prior_box_var', [10, 4])
            locations = float_input('target_box', [20, 4])
            confidences = float_input('scores', [2, 20, 10])
            detections = layers.detection_output(
                scores=confidences,
                loc=locations,
                prior_box=priors,
                prior_box_var=prior_vars)
            self.assertIsNotNone(detections)
        print(str(net))
class TestPriorBox(unittest.TestCase):
    def test_prior_box(self):
        # Exercise the CPU path first, then the (currently skipped) CUDA path.
        for on_gpu in (False, True):
            self.check_prior_box(use_cuda=on_gpu)

    def prior_box_output(self, data_shape):
        # Build a chain of five stride-2 convolutions on top of the input
        # image and feed every intermediate feature map to prior_box.
        images = fluid.layers.data(
            name='pixel', shape=data_shape, dtype='float32')

        feature_maps = []
        current = images
        for _ in range(5):
            current = fluid.layers.conv2d(
                input=current,
                num_filters=3,
                filter_size=3,
                stride=2,
                use_cudnn=False)
            feature_maps.append(current)
        # The last feature map is passed twice, giving six inputs in total.
        feature_maps.append(feature_maps[-1])

        # steps=[8, 16, 32, 64, 100, 300] is intentionally not passed.
        box, var = detection.prior_box(
            inputs=feature_maps,
            image=images,
            min_ratio=20,
            max_ratio=90,
            aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
            base_size=300,
            offset=0.5,
            flip=True,
            clip=True)
        return box, var

    def check_prior_box(self, use_cuda):
        # prior_box only supports CPU, so the CUDA variant is a no-op.
        if use_cuda:
            return

        data_shape = [3, 224, 224]
        box, var = self.prior_box_output(data_shape)

        place = fluid.CPUPlace() if not use_cuda else fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        # Static shape checks on the layer outputs.
        assert box.shape[1] == 4
        assert var.shape[1] == 4
        assert box.shape == var.shape
        assert len(box.shape) == 2

        # The batch size (4) is not used by prior_box itself.
        image_np = np.random.random([4] + data_shape).astype("float32")
        image_tensor = core.LoDTensor()
        image_tensor.set(image_np, place)

        box_np, var_np = exe.run(fluid.default_main_program(),
                                 feed={'pixel': image_tensor},
                                 fetch_list=[box, var])

        # Runtime shapes must agree with the statically inferred ones.
        assert var_np.shape == var.shape
        assert box_np.shape == box.shape
if __name__ == '__main__':
unittest.main()
......@@ -137,7 +137,7 @@ def batched_multiclass_nms(boxes, scores, background, score_threshold,
det_outs = []
lod = [0]
for n in range(batch_size):
nmsed_outs, nmsed_num = multiclass_nms(boxes, scores[n], background,
nmsed_outs, nmsed_num = multiclass_nms(boxes[n], scores[n], background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
lod.append(lod[-1] + nmsed_num)
......@@ -145,7 +145,7 @@ def batched_multiclass_nms(boxes, scores, background, score_threshold,
for c, indices in nmsed_outs.iteritems():
for idx in indices:
xmin, ymin, xmax, ymax = boxes[idx][:]
xmin, ymin, xmax, ymax = boxes[n][idx][:]
det_outs.append([c, scores[n][c][idx], xmin, ymin, xmax, ymax])
return det_outs, lod
......@@ -179,9 +179,9 @@ class TestMulticlassNMSOp(OpTest):
scores = np.reshape(scores, (N, M, C))
scores = np.transpose(scores, (0, 2, 1))
boxes = np.random.random((M, BOX_SIZE)).astype('float32')
boxes[:, 0:2] = boxes[:, 0:2] * 0.5
boxes[:, 2:4] = boxes[:, 2:4] * 0.5 + 0.5
boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5
nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers.detection as detection
import paddle.v2.fluid.core as core
import unittest
def prior_box_output(data_shape):
    """Build a small conv stack and attach a prior_box layer to it.

    Five stride-2 3x3 convolutions are chained on top of a 'pixel' data
    layer of shape `data_shape`; their outputs (the last one twice) are the
    inputs of `detection.prior_box`. Returns the (box, var) pair produced by
    that layer.
    """
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')

    feature_maps = []
    current = images
    for _ in range(5):
        current = fluid.layers.conv2d(
            input=current,
            num_filters=3,
            filter_size=3,
            stride=2,
            use_cudnn=False)
        feature_maps.append(current)
    # The final feature map is used twice, giving six inputs in total.
    feature_maps.append(feature_maps[-1])

    # steps=[8, 16, 32, 64, 100, 300] is intentionally not passed.
    box, var = detection.prior_box(
        inputs=feature_maps,
        image=images,
        min_ratio=20,
        max_ratio=90,
        aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
        base_size=300,
        offset=0.5,
        flip=True,
        clip=True)
    return box, var
def main(use_cuda):
    """Run the prior_box network once and check the output shapes.

    Does nothing when `use_cuda` is truthy, because prior_box only
    supports CPU.
    """
    if use_cuda:
        return

    data_shape = [3, 224, 224]
    box, var = prior_box_output(data_shape)

    place = fluid.CPUPlace() if not use_cuda else fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Static shape checks on the layer outputs.
    assert box.shape[1] == 4
    assert var.shape[1] == 4
    assert box.shape == var.shape
    assert len(box.shape) == 2

    # A single forward pass; the batch size (4) is not used by prior_box.
    image_np = np.random.random([4] + data_shape).astype("float32")
    image_tensor = core.LoDTensor()
    image_tensor.set(image_np, place)

    box_np, var_np = exe.run(fluid.default_main_program(),
                             feed={'pixel': image_tensor},
                             fetch_list=[box, var])

    # Runtime shapes must agree with the statically inferred ones.
    assert var_np.shape == var.shape
    assert box_np.shape == box.shape
class TestFitALine(unittest.TestCase):
    # Drives the module-level main() once per device flavour.
    def _run(self, on_gpu):
        main(use_cuda=on_gpu)

    def test_cpu(self):
        self._run(False)

    def test_cuda(self):
        self._run(True)
if __name__ == '__main__':
unittest.main()
......@@ -73,5 +73,20 @@ class TestSequenceExpandCase3(TestSequenceExpand):
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
class TestSequenceExpandCase4(TestSequenceExpand):
    def set_data(self):
        # X: a fixed 2x5 float32 matrix with a one-level LoD of two
        # single-row sequences.
        x_rows = [
            [0.1, 0.3, 0.2, 0.15, 0.25],
            [0.2, 0.15, 0.25, 0.1, 0.3],
        ]
        x_data = np.array(x_rows).astype('float32')
        x_lod = [[0, 1, 2]]

        # Y: random 2x1 float32 data carrying a two-level LoD.
        y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32')
        y_lod = [[0, 1, 2], [0, 1, 2]]

        self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
if __name__ == '__main__':
unittest.main()
......@@ -20,11 +20,11 @@ from op_test import OpTest
class TestSplitOp(OpTest):
def setUp(self):
self.op_type = "split"
axis = 0
x = np.random.random((4, 2, 5)).astype('float32')
out = np.split(x, [1, 3], axis)
axis = 1
x = np.random.random((4, 5, 6)).astype('float32')
out = np.split(x, [2, 3], axis)
self.inputs = {'X': x}
self.attrs = {'axis': axis, 'sections': [1, 2, 1]}
self.attrs = {'axis': axis, 'sections': [2, 1, 2]}
self.outputs = {'Out': [('out%d' % i, out[i]) \
for i in xrange(len(out))]}
......
......@@ -43,7 +43,7 @@ def gen_match_and_neg_indices(num_prior, gt_lod, neg_lod):
def target_assign(encoded_box, gt_label, match_indices, neg_indices, gt_lod,
neg_lod, background_label):
neg_lod, mismatch_value):
batch_size, num_prior = match_indices.shape
# init target bbox
......@@ -52,7 +52,7 @@ def target_assign(encoded_box, gt_label, match_indices, neg_indices, gt_lod,
trg_box_wt = np.zeros((batch_size, num_prior, 1)).astype('float32')
# init target label
trg_label = np.ones((batch_size, num_prior, 1)).astype('int32')
trg_label = trg_label * background_label
trg_label = trg_label * mismatch_value
# init weight for target label
trg_label_wt = np.zeros((batch_size, num_prior, 1)).astype('float32')
......@@ -65,53 +65,90 @@ def target_assign(encoded_box, gt_label, match_indices, neg_indices, gt_lod,
# target bbox
for v, c in zip(col_val + gt_start, col_ids[0].tolist()):
trg_box[i][c][:] = encoded_box[v][c][:]
# weight for target bbox
trg_box_wt[i][col_ids] = 1.0
trg_label[i][col_ids] = gt_label[col_val + gt_start]
trg_label_wt[i][col_ids] = 1.0
# set target label weight to 1.0 for the negative samples
neg_ids = neg_indices[neg_lod[i]:neg_lod[i + 1]]
trg_label_wt[i][neg_ids] = 1.0
if neg_indices is not None:
neg_ids = neg_indices[neg_lod[i]:neg_lod[i + 1]]
trg_label_wt[i][neg_ids] = 1.0
return trg_box, trg_box_wt, trg_label, trg_label_wt
class TestTargetAssginOp(OpTest):
class TestTargetAssginFloatType(OpTest):
    def setUp(self):
        # Checks target_assign on float32 box encodings: only X and
        # MatchIndices are fed, and only Out / OutWeight are compared.
        self.op_type = "target_assign"

        num_prior = 120
        num_class = 21
        gt_lod = [0, 5, 11, 23]
        neg_lod = [0, 4, 7, 13]
        mismatch_value = 0
        num_gt = gt_lod[-1]

        # Random inputs; the call order of the random generators is kept.
        encoded_box = np.random.random((num_gt, num_prior, 4)).astype('float32')
        gt_label = np.random.randint(
            num_class, size=(num_gt, 1)).astype('int32')
        match_indices, neg_indices = gen_match_and_neg_indices(
            num_prior, gt_lod, neg_lod)

        # Reference result; the label outputs are not needed in this case.
        box_target, box_weight, _label, _label_weight = target_assign(
            encoded_box, gt_label, match_indices, neg_indices, gt_lod,
            neg_lod, mismatch_value)

        # assign regression targets
        self.inputs = {
            'X': (encoded_box, [gt_lod]),
            'MatchIndices': match_indices,
        }
        self.attrs = {'mismatch_value': mismatch_value}
        self.outputs = {
            'Out': box_target,
            'OutWeight': box_weight,
        }

    def test_check_output(self):
        self.check_output()
class TestTargetAssginIntType(OpTest):
def setUp(self):
self.op_type = "target_assign"
num_prior = 120
num_class = 21
gt_lod = [0, 5, 11, 23]
neg_lod = [0, 4, 7, 13]
mismatch_value = 0
batch_size = len(gt_lod) - 1
num_gt = gt_lod[-1]
background_label = 0
encoded_box = np.random.random((num_gt, num_prior, 4)).astype('float32')
gt_label = np.random.randint(
num_class, size=(num_gt, 1)).astype('int32')
match_indices, neg_indices = gen_match_and_neg_indices(num_prior,
gt_lod, neg_lod)
trg_box, trg_box_wt, trg_label, trg_label_wt = target_assign(
encoded_box, gt_label, match_indices, neg_indices, gt_lod, neg_lod,
background_label)
_, _, out, out_wt, = target_assign(encoded_box, gt_label, match_indices,
neg_indices, gt_lod, neg_lod,
mismatch_value)
# assign classification targets
x = np.reshape(gt_label, (num_gt, 1, 1))
self.inputs = {
'EncodedGTBBox': (encoded_box, [gt_lod]),
'GTScoreLabel': (gt_label, [gt_lod]),
'MatchIndices': (match_indices),
'X': (x, [gt_lod]),
'MatchIndices': match_indices,
'NegIndices': (neg_indices, [neg_lod]),
}
self.attrs = {'background_label': background_label}
self.attrs = {'mismatch_value': mismatch_value}
self.outputs = {
'PredBBoxLabel': (trg_box),
'PredBBoxWeight': (trg_box_wt),
'PredScoreLabel': (trg_label),
'PredScoreWeight': (trg_label_wt),
'Out': out,
'OutWeight': out_wt,
}
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册