Commit 8ea2288e authored by chengduoZH

remove conflict

@@ -16,12 +16,10 @@ function(copy TARGET)
   foreach(index RANGE ${len})
     list(GET copy_lib_SRCS ${index} src)
     list(GET copy_lib_DSTS ${index} dst)
-    add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND mkdir -p "${dst}")
-    if(IS_DIRECTORY ${src})
-      add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND cp -r "${src}" "${dst}")
-    else()
-      add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND cp "${src}" "${dst}")
-    endif()
+    add_custom_command(TARGET ${TARGET} PRE_BUILD
+      COMMAND mkdir -p "${dst}"
+      COMMAND cp -r "${src}" "${dst}"
+      COMMENT "copying ${src} -> ${dst}")
   endforeach()
 endfunction()
@@ -53,11 +51,11 @@ IF(NOT PROTOBUF_FOUND)
 ENDIF(NOT PROTOBUF_FOUND)
 
 # paddle fluid module
-set(src_dir "${PADDLE_SOURCE_DIR}/paddle")
-set(dst_dir "${CMAKE_INSTALL_PREFIX}/paddle")
+set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
+set(dst_dir "${CMAKE_INSTALL_PREFIX}/paddle/fluid")
 set(module "framework")
 copy(framework_lib DEPS framework_py_proto
-  SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/framework/framework.pb.h
+  SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h
   DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module}
 )
@@ -69,7 +67,7 @@ copy(memory_lib
 set(module "inference")
 copy(inference_lib DEPENDS paddle_fluid_shared
-  SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/inference/libpaddle_fluid.so
+  SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.so
   DSTS ${dst_dir}/${module} ${dst_dir}/${module}
 )
......
@@ -25,7 +25,10 @@ namespace framework {
 class CosineOp : public OperatorBase {
  public:
   using OperatorBase::OperatorBase;
-  void Run(const Scope& scope, const platform::Place& place) const override {}
+
+ private:
+  void RunImpl(const Scope& scope,
+               const platform::Place& place) const override {}
 };
 
 class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 class MyTestOp : public OperatorBase {
  public:
   using OperatorBase::OperatorBase;
-  void Run(const Scope& scope, const platform::Place& place) const override {}
+
+ private:
+  void RunImpl(const Scope& scope,
+               const platform::Place& place) const override {}
 };
 
 class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......
@@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
   }
 }
 
+void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
+  if (platform::is_gpu_place(place)) {
+#ifndef PADDLE_WITH_CUDA
+    PADDLE_THROW("Cannot run operator on place %s", place);
+#else
+    auto dev_id = boost::get<platform::CUDAPlace>(place).device;
+    platform::SetDeviceId(dev_id);
+#endif
+  }
+  RunImpl(scope, place);
+}
+
 std::string OperatorBase::Input(const std::string& name) const {
   auto& ins = Inputs(name);
   PADDLE_ENFORCE_LE(ins.size(), 1UL,
@@ -479,8 +491,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
   const Scope& scope_;
 };
 
-void OperatorWithKernel::Run(const Scope& scope,
-                             const platform::Place& place) const {
+void OperatorWithKernel::RunImpl(const Scope& scope,
+                                 const platform::Place& place) const {
   RuntimeInferShapeContext infer_shape_ctx(*this, scope);
   this->InferShape(&infer_shape_ctx);
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
......
@@ -89,8 +89,9 @@ class OperatorBase {
   std::string DebugString() const { return DebugStringEx(nullptr); }
 
-  /// Net will call this function to Run an op.
-  virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
+  /// Net will call this interface function to Run an op.
+  //  The implementation should be written at RunImpl
+  void Run(const Scope& scope, const platform::Place& place);
 
   // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
   virtual void Stop() {}
@@ -144,6 +145,8 @@ class OperatorBase {
  private:
   void GenerateTemporaryNames();
   void CheckAllInputOutputSet() const;
+  virtual void RunImpl(const Scope& scope,
+                       const platform::Place& place) const = 0;
 };
 
 // Macro for define a clone method.
@@ -168,10 +171,13 @@ class OperatorBase {
 class NOP : public OperatorBase {
  public:
   using OperatorBase::OperatorBase;
-  void Run(const Scope& scope, const platform::Place& place) const override {}
   std::unique_ptr<OperatorBase> Clone() const override {
     return std::unique_ptr<OperatorBase>(new NOP(*this));
   }
+
+ private:
+  void RunImpl(const Scope& scope,
+               const platform::Place& place) const override {}
 };
 
 class ExecutionContext {
@@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
                      const VariableNameMap& outputs, const AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const Scope& scope, const platform::Place& place) const final;
-
   static std::unordered_map<std::string /* op_type */, OpKernelMap>&
   AllOpKernels() {
     static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
@@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
   // indicate kernel DataType by input data. Defaultly all input data must be
   // same.
   proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
+  void RunImpl(const Scope& scope, const platform::Place& place) const final;
 };
 
 extern bool OpSupportGPU(const std::string& op_type);
......
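The operator.h/operator.cc hunks above replace the public virtual Run with a non-virtual Run that performs the shared device setup once and then dispatches to a private virtual RunImpl (the non-virtual-interface idiom). A minimal, self-contained sketch of that idiom follows; Scope, Place, and MyOp are simplified placeholders, not the real Paddle types or signatures.

#include <iostream>
#include <string>

// Placeholder stand-ins for framework::Scope and platform::Place.
struct Scope {};
using Place = std::string;

class OperatorBase {
 public:
  virtual ~OperatorBase() = default;

  // Public, non-virtual entry point: common work happens here once,
  // then control is handed to the private hook.
  void Run(const Scope& scope, const Place& place) {
    if (place == "GPU") {
      // In the real framework this is where the CUDA device id is set,
      // or an error is thrown when CUDA is not compiled in.
      std::cout << "setting device for " << place << "\n";
    }
    RunImpl(scope, place);
  }

 private:
  // Derived operators override only this hook.
  virtual void RunImpl(const Scope& scope, const Place& place) const = 0;
};

class MyOp : public OperatorBase {
 private:
  void RunImpl(const Scope& scope, const Place& place) const override {
    (void)scope;
    std::cout << "MyOp::RunImpl on " << place << "\n";
  }
};

int main() {
  Scope scope;
  MyOp op;
  op.Run(scope, "GPU");  // always goes through OperatorBase::Run
}

Because derived classes can only override RunImpl, every operator invocation is forced through the single Run entry point, which is what lets the commit centralize the GPU device-id handling.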
@@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase {
   OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
                       const VariableNameMap& outputs, const AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs), x(1) {}
-  void Run(const Scope& scope, const platform::Place& place) const override {
+
+ private:
+  void RunImpl(const Scope& scope,
+               const platform::Place& place) const override {
     ++op_run_num;
     ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
     ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
@@ -259,8 +262,10 @@ class OperatorClone : public paddle::framework::OperatorBase {
                 const paddle::framework::VariableNameMap& outputs,
                 const paddle::framework::AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const paddle::framework::Scope& scope,
-           const paddle::platform::Place& place) const override {}
+
+ private:
+  void RunImpl(const paddle::framework::Scope& scope,
+               const paddle::platform::Place& place) const override {}
 };
 
 TEST(Operator, Clone) {
......
@@ -31,8 +31,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
                      const framework::VariableNameMap &outputs,
                      const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
     auto &rank_table =
         scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();
......
@@ -71,8 +71,10 @@ class AssignOp : public framework::OperatorBase {
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto *x = scope.FindVar(Input("X"));
     if (x == nullptr) {
       return;
......
@@ -55,8 +55,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
                     const framework::VariableNameMap& outputs,
                     const framework::AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope& scope,
-           const platform::Place& dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& dev_ctx = *pool.Get(dev_place);
......
@@ -204,8 +204,9 @@ class BeamSearchOp : public framework::OperatorBase {
     PADDLE_THROW("Not Implemented");
   }
 
-  void Run(const framework::Scope& scope,
-           const platform::Place& dev_place) const override {
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
     auto ids_var = scope.FindVar(Input("ids"));
     auto scores_var = scope.FindVar(Input("scores"));
     auto pre_ids_var = scope.FindVar(Input("pre_ids"));
......
@@ -38,7 +38,7 @@ class ConcatKernel : public framework::OpKernel<T> {
       auto in_stride = framework::stride_numel(in->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
                                   out->data<T>() + output_offset, out_stride,
-                                  in->data<T>(), in_stride);
+                                  in->data<T>(), in_stride, in_stride[axis]);
       output_offset += in_stride[axis];
     }
   }
@@ -59,7 +59,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
       auto out_stride = framework::stride_numel(out->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
                                   out_stride, in->data<T>() + input_offset,
-                                  in_stride);
+                                  in_stride, out_stride[axis]);
       input_offset += out_stride[axis];
     }
   }
......
@@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
   }
 }
 
-void CondOp::Run(const Scope& scope, const platform::Place& place) const {
+void CondOp::RunImpl(const Scope& scope, const platform::Place& place) const {
   // get device context from pool
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto& dev_ctx = *pool.Get(place);
......
@@ -77,8 +77,9 @@ class CondOp : public framework::OperatorBase {
     sub_net_op_[FALSE_BRANCH] = std::move(net);
   }
 
-  void Run(const framework::Scope& scope,
-           const platform::Place& place) const override;
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override;
 
  private:
   const int TRUE_BRANCH = 0;
......
@@ -65,8 +65,10 @@ class ConditionalBlockOp : public ConditionalOp {
                      const framework::VariableNameMap &outputs,
                      const framework::AttributeMap &attrs)
       : ConditionalOp(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     auto xs = InputTensors(scope);
 
     bool need_run;
@@ -128,8 +130,10 @@ class ConditionalBlockGradOp : public ConditionalOp {
                          const framework::VariableNameMap &outputs,
                          const framework::AttributeMap &attrs)
       : ConditionalOp(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     auto xs = this->InputTensors(scope);
 
     bool need_run;
......
@@ -106,8 +106,10 @@ template <typename T>
 class CreateRandomDataGeneratorOp : public framework::OperatorBase {
  public:
   using framework::OperatorBase::OperatorBase;
-  void Run(const framework::Scope& scope,
-           const platform::Place& dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
     const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
     const auto& ranks = Attr<std::vector<int>>("ranks");
     PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
@@ -155,8 +157,10 @@ class CreateRandomDataGeneratorOpMaker
 class CreateShuffleReaderOp : public framework::OperatorBase {
  public:
   using framework::OperatorBase::OperatorBase;
-  void Run(const framework::Scope& scope,
-           const platform::Place& dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
     const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                         ->Get<framework::ReaderHolder>();
     auto* out = scope.FindVar(Output("Out"))
@@ -187,8 +191,10 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
 class CreateBatchReaderOp : public framework::OperatorBase {
  public:
   using framework::OperatorBase::OperatorBase;
-  void Run(const framework::Scope& scope,
-           const platform::Place& dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
     const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                         ->Get<framework::ReaderHolder>();
     auto* out = scope.FindVar(Output("Out"))
......
@@ -24,8 +24,10 @@ class FeedOp : public framework::OperatorBase {
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto feed_var_name = Input("X");
     auto *feed_var = scope.FindVar(feed_var_name);
......
@@ -26,8 +26,9 @@ class FetchOp : public framework::OperatorBase {
           const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto fetch_var_name = Input("X");
     auto *fetch_var = scope.FindVar(fetch_var_name);
     PADDLE_ENFORCE(fetch_var != nullptr,
......
@@ -33,8 +33,10 @@ class FillConstantInferShape : public framework::InferShapeBase {
 class FillConstantOp : public framework::OperatorBase {
  public:
   using framework::OperatorBase::OperatorBase;
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     auto data_type =
         static_cast<framework::proto::DataType>(Attr<int>("dtype"));
     auto value = Attr<float>("value");
......
@@ -42,8 +42,10 @@ class FillOp : public framework::OperatorBase {
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto &out =
         detail::Ref(detail::Ref(scope.FindVar(Output("Out")),
                                 "Cannot find variable %s", Output("Out"))
......
@@ -37,8 +37,10 @@ class GetPlacesOp : public framework::OperatorBase {
               const framework::VariableNameMap &outputs,
               const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     bool is_gpu;
     if (Attr<std::string>("device_type") == "AUTO") {
       is_gpu = platform::is_gpu_place(place);
......
@@ -51,8 +51,9 @@ class IncrementOp : public framework::OperatorBase {
               const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
     auto &out =
         *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
......
@@ -28,8 +28,9 @@ class IsEmptyOp : public framework::OperatorBase {
             const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     // get input
     auto *var = scope.FindVar(Input(kInput));
     PADDLE_ENFORCE_NOT_NULL(var);
......
@@ -26,8 +26,10 @@ class LoadCombineOp : public framework::OperatorBase {
                 const framework::VariableNameMap &outputs,
                 const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto filename = Attr<std::string>("file_path");
     std::ifstream fin(filename);
......
@@ -25,8 +25,10 @@ class LoadOp : public framework::OperatorBase {
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto filename = Attr<std::string>("file_path");
     std::ifstream fin(filename);
     PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
......
@@ -25,8 +25,10 @@ class LoDArrayLengthOp : public framework::OperatorBase {
                    const framework::VariableNameMap &outputs,
                    const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
     auto &out =
         *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
......
@@ -23,8 +23,10 @@ class LoDRankTableOp : public framework::OperatorBase {
                  const framework::VariableNameMap &outputs,
                  const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     auto x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
     auto *out =
         scope.FindVar(Output("Out"))->GetMutable<framework::LoDRankTable>();
......
@@ -32,8 +32,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
                      const framework::VariableNameMap &outputs,
                      const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input %s",
                           Input("X"))
                   .Get<framework::LoDTensor>();
......
@@ -27,8 +27,9 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
                  const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     auto &rank_table =
         scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();
     auto *out =
......
@@ -27,8 +27,10 @@ class MergeLoDTensorOp : public framework::OperatorBase {
                    const framework::VariableNameMap &outputs,
                    const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     // get device context from pool
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
......
@@ -237,6 +237,8 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
     }
 
     ctx->SetOutputDim("UpdatedMatchIndices", idx_dims);
+    // The first dimension of NegIndices will be set correcttly in Compute.
+    ctx->SetOutputDim("NegIndices", {-1, 1});
   }
 
  protected:
......
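The MineHardExamplesOp hunk above registers a NegIndices output whose first dimension is not known at shape-inference time, so InferShape records a -1 placeholder and the kernel fills in the real extent once the data has been mined. A small sketch of that placeholder-then-resize pattern, with made-up helper names (InferNegIndicesDim and ResizeAtCompute are illustrative only, not Paddle APIs):

#include <cstdio>
#include <vector>

// Toy stand-in: a "tensor shape" is just a vector of extents here.
struct Dims { std::vector<long> d; };

// Shape inference runs before any data is available, so the number of
// mined negatives is unknown; a -1 placeholder is recorded for now.
Dims InferNegIndicesDim() { return Dims{{-1, 1}}; }

// At compute time the real count is known and the placeholder is replaced.
Dims ResizeAtCompute(long neg_count) { return Dims{{neg_count, 1}}; }

int main() {
  Dims inferred = InferNegIndicesDim();
  Dims actual = ResizeAtCompute(/*neg_count=*/37);
  std::printf("inferred: {%ld, %ld}, actual: {%ld, %ld}\n",
              inferred.d[0], inferred.d[1], actual.d[0], actual.d[1]);
}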
@@ -26,8 +26,9 @@ class NCCLInitOp : public framework::OperatorBase {
               const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     const auto &name = Output("Communicator");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
                             "Can not find variable '%s' in the scope.", name);
......
@@ -57,20 +57,6 @@ class NetOp : public framework::OperatorBase {
     this->CompleteAddOp();
   }
 
-  /**
-   * @brief Run the network.
-   *
-   * Run all the operators with the `scope`, if no scope is provided, default
-   * scope will be used instead. If no OpContext is provicded, default context
-   * will be used.
-   */
-  void Run(const framework::Scope& scope,
-           const platform::Place& place) const override {
-    for (auto& op : ops_) {
-      op->Run(scope, place);
-    }
-  }
-
   bool SupportGPU() const override {
     for (auto& op : ops_) {
       if (!op->SupportGPU()) {
@@ -117,6 +103,20 @@ class NetOp : public framework::OperatorBase {
   std::vector<std::unique_ptr<framework::OperatorBase>> ops_;
 
  private:
+  /**
+   * @brief Run the network.
+   *
+   * Run all the operators with the `scope`, if no scope is provided, default
+   * scope will be used instead. If no OpContext is provicded, default context
+   * will be used.
+   */
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override {
+    for (auto& op : ops_) {
+      op->Run(scope, place);
+    }
+  }
+
   bool add_op_done_{false};
   std::set<std::string> intermediate_outputs_;
......
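In the NetOp hunk the loop over sub-operators moves into the private RunImpl, but each child is still started through its public Run, so nested operators also pass through the common pre-run path. A self-contained sketch of that composite arrangement, again with placeholder types (NetLikeOp and LeafOp are illustrative, not the real Paddle classes):

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Scope {};
using Place = std::string;

class OperatorBase {
 public:
  virtual ~OperatorBase() = default;
  void Run(const Scope& scope, const Place& place) {
    std::cout << "common pre-run work (device setup) for " << place << "\n";
    RunImpl(scope, place);
  }

 private:
  virtual void RunImpl(const Scope& scope, const Place& place) const = 0;
};

// A composite operator: its own RunImpl just runs every child through the
// child's public Run, so the children also get the pre-run work.
class NetLikeOp : public OperatorBase {
 public:
  void Add(std::unique_ptr<OperatorBase> op) { ops_.push_back(std::move(op)); }

 private:
  void RunImpl(const Scope& scope, const Place& place) const override {
    for (auto& op : ops_) {
      op->Run(scope, place);
    }
  }
  std::vector<std::unique_ptr<OperatorBase>> ops_;
};

class LeafOp : public OperatorBase {
 private:
  void RunImpl(const Scope&, const Place&) const override {
    std::cout << "leaf op ran\n";
  }
};

int main() {
  NetLikeOp net;
  net.Add(std::make_unique<LeafOp>());
  net.Add(std::make_unique<LeafOp>());
  Scope scope;
  net.Run(scope, "CPU");
}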
@@ -26,7 +26,10 @@ class TestOp : public framework::OperatorBase {
  public:
   using framework::OperatorBase::OperatorBase;
   DEFINE_OP_CLONE_METHOD(TestOp);
-  void Run(const Scope& scope, const platform::Place& place) const override {
+
+ private:
+  void RunImpl(const Scope& scope,
+               const platform::Place& place) const override {
     ++run_cnt;
   }
 };
......
@@ -118,8 +118,9 @@ class ParallelDoOp : public framework::OperatorBase {
                const framework::AttributeMap &attrs)
       : framework::OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     // get device context from pool
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(place);
@@ -207,8 +208,9 @@ class ParallelDoGradOp : public framework::OperatorBase {
                    const framework::AttributeMap &attrs)
       : framework::OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto *block = Attr<framework::BlockDesc *>(kParallelBlock);
     auto *program = block->Program();
......
@@ -130,8 +130,9 @@ class TensorPrintOp : public framework::OperatorBase {
     PADDLE_THROW("Not implemented.");
   }
 
-  void Run(const framework::Scope& scope,
-           const platform::Place& place) const override {
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override {
     const framework::Variable* in_var_ptr = nullptr;
     std::string phase = kForward;
     std::string printed_var_name = "";
......
@@ -54,8 +54,10 @@ class ReadInferVarType : public framework::VarTypeInference {
 class ReadOp : public framework::OperatorBase {
  public:
   using framework::OperatorBase::OperatorBase;
-  void Run(const framework::Scope& scope,
-           const platform::Place& dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
     framework::ReaderHolder* reader =
         scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
     if (!reader->HasNext()) {
......
@@ -226,8 +226,9 @@ class RecurrentOp : public RecurrentBase {
               const framework::AttributeMap &attrs)
       : RecurrentBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
     VLOG(3) << "Static RNN input sequence length = " << seq_len;
     StepScopes scopes = CreateStepScopes(scope, seq_len);
@@ -315,8 +316,9 @@ class RecurrentGradOp : public RecurrentBase {
                   const framework::AttributeMap &attrs)
       : RecurrentBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto seq_len = static_cast<size_t>(GetSequenceLength(scope));
     StepScopes scopes = CreateStepScopes(scope, seq_len);
     auto reverse = Attr<bool>(kReverse);
......
@@ -75,8 +75,10 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
                                   const framework::VariableNameMap &outputs,
                                   const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto &x =
         detail::Ref(scope.FindVar(Input("X")),
                     "Cannot find input lod tensor variable %s", Input("X"))
......
@@ -24,8 +24,10 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
                     const framework::VariableNameMap &outputs,
                     const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     auto mem_var_name = Input("X");
     auto *mem_var = scope.FindVar(mem_var_name);
     PADDLE_ENFORCE(mem_var != nullptr,
@@ -76,8 +78,10 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
                         const framework::VariableNameMap &outputs,
                         const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     auto out_grad_var_name = Input(framework::GradVarName("Out"));
     auto *out_grad_var = scope.FindVar(out_grad_var_name);
......
@@ -63,8 +63,10 @@ class SaveCombineOp : public framework::OperatorBase {
                 const framework::VariableNameMap &outputs,
                 const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto filename = Attr<std::string>("file_path");
     auto overwrite = Attr<bool>("overwrite");
......
@@ -62,8 +62,10 @@ class SaveOp : public framework::OperatorBase {
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto filename = Attr<std::string>("file_path");
     auto overwrite = Attr<bool>("overwrite");
......
@@ -27,8 +27,9 @@ class ShrinkRNNMemoryOp : public ArrayOp {
                     const framework::AttributeMap &attrs)
       : ArrayOp(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto *x_var = scope.FindVar(Input("X"));
     PADDLE_ENFORCE(x_var != nullptr, "Input X must be set");
     auto &x_tensor = x_var->Get<framework::LoDTensor>();
@@ -108,8 +109,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
                         const framework::AttributeMap &attrs)
       : ArrayOp(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
     auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
     PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
......
@@ -33,8 +33,10 @@ class SplitLoDTensorOp : public framework::OperatorBase {
                    const framework::VariableNameMap &outputs,
                    const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
     auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
     auto *out_true =
......
@@ -38,7 +38,7 @@ class SplitOpKernel : public framework::OpKernel<T> {
       auto out_stride = framework::stride_numel(out->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
                                   out_stride, in->data<T>() + input_offset,
-                                  in_stride);
+                                  in_stride, out_stride[axis]);
       input_offset += out_stride[axis];
     }
   }
......
@@ -54,7 +54,8 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
                                      int64_t axis, T* dst,
                                      const framework::DDim& dst_stride_numel,
                                      const T* src,
-                                     const framework::DDim& src_stride_numel) {
+                                     const framework::DDim& src_stride_numel,
+                                     int64_t size) {
   int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
   int64_t src_after = src_stride_numel[axis];
   int64_t dst_after = dst_stride_numel[axis];
@@ -82,15 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
     if (platform::is_cpu_place(place)) {
       auto& cpu_place = boost::get<platform::CPUPlace>(place);
       memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
-                   src + i * src_after, sizeof(T) * src_after);
+                   src + i * src_after, sizeof(T) * size);
     } else {
 #ifdef PADDLE_WITH_CUDA
       auto& gpu_place = boost::get<platform::CUDAPlace>(place);
       auto& cuda_ctx =
           reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
       memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
-                   src + i * src_after, sizeof(T) * src_after,
-                   cuda_ctx.stream());
+                   src + i * src_after, sizeof(T) * size, cuda_ctx.stream());
 #else
       PADDLE_THROW("Paddle is not compiled with GPU");
 #endif
......
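The StridedNumelCopyWithAxis change above makes the number of elements copied per row an explicit argument instead of deriving it from the source stride, which is what the concat and split call sites now pass as in_stride[axis] and out_stride[axis]. A hedged, self-contained sketch of the underlying row-wise copy idea, with made-up names (StridedRowCopy is not the Paddle helper):

#include <cstring>
#include <iostream>
#include <vector>

// Copy `rows` rows of `size` elements each from src to dst, where the two
// buffers may have different row widths (strides). Passing the width
// explicitly lets the same routine serve both the narrow and the wide side.
template <typename T>
void StridedRowCopy(T* dst, long dst_stride, const T* src, long src_stride,
                    long rows, long size) {
  for (long i = 0; i < rows; ++i) {
    std::memcpy(dst + i * dst_stride, src + i * src_stride, sizeof(T) * size);
  }
}

int main() {
  // Concat-like use: two 2x2 inputs packed into a 2x4 output along axis 1.
  std::vector<float> a = {1, 2, 3, 4}, b = {5, 6, 7, 8}, out(8, 0);
  StridedRowCopy(out.data() + 0, /*dst_stride=*/4, a.data(), /*src_stride=*/2,
                 /*rows=*/2, /*size=*/2);
  StridedRowCopy(out.data() + 2, /*dst_stride=*/4, b.data(), /*src_stride=*/2,
                 /*rows=*/2, /*size=*/2);
  for (float v : out) std::cout << v << " ";  // prints 1 2 5 6 3 4 7 8
  std::cout << "\n";
}

With the width explicit, the concat kernel can copy a narrow input row into a wider destination row, and the split kernel can copy a narrow slice out of a wider source row, without the helper guessing the width from either stride.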
@@ -24,8 +24,9 @@ class WriteToArrayOp : public ArrayOp {
                  const framework::AttributeMap &attrs)
       : ArrayOp(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto *x = scope.FindVar(Input("X"));
     if (x == nullptr) return;
     auto &x_tensor = x->Get<framework::LoDTensor>();
@@ -122,8 +123,10 @@ class ReadFromArrayOp : public ArrayOp {
                   const framework::VariableNameMap &outputs,
                   const framework::AttributeMap &attrs)
       : ArrayOp(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::Place &place) const override {
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
     auto *x = scope.FindVar(Input("X"));
     PADDLE_ENFORCE(x != nullptr, "X must be set");
     auto &x_array = x->Get<framework::LoDTensorArray>();
......
@@ -39,8 +39,9 @@ class WhileOp : public framework::OperatorBase {
           const framework::AttributeMap &attrs)
       : framework::OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
     auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
     PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
@@ -99,8 +100,9 @@ class WhileGradOp : public framework::OperatorBase {
               const framework::AttributeMap &attrs)
       : framework::OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     // get device context from pool
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
......
@@ -204,6 +204,17 @@ function gen_capi_package() {
   fi
 }
 
+function gen_fluid_inference_lib() {
+  if [ ${WITH_C_API:-OFF} == "OFF" ] ; then
+    cat <<EOF
+    ========================================
+    Building fluid inference library ...
+    ========================================
+EOF
+    make inference_lib_dist
+  fi
+}
+
 set -xe
 
 cmake_gen ${PYTHON_ABI:-""}
@@ -212,6 +223,7 @@ run_test
 gen_docs
 gen_dockerfile
 gen_capi_package
+gen_fluid_inference_lib
 
 if [[ ${WITH_C_API:-OFF} == "ON" ]]; then
   printf "PaddlePaddle C-API libraries was generated on build/paddle.tgz\n"
......
@@ -121,6 +121,7 @@ def split_dense_variable(var_list,
                 block_size += dim1 - remains
         # update split_count after aligning
         split_count = int(math.ceil(var_numel / float(block_size)))
+        print("###split var ", var.name, var.shape, block_size, split_count)
         for block_id in xrange(split_count):
             curr_block_size = min(block_size, var_numel - (
                 (block_id) * block_size))
@@ -191,7 +192,6 @@ class DistributeTranspiler:
         for b in param_blocks:
             varname, block_id, _ = b.split(":")
             send_outputs.append(param_var_mapping[varname][int(block_id)])
-
         # let send_op know which endpoint to send which var to, eplist has the same
         # order as send_inputs.
         eplist = split_method(send_inputs, pserver_endpoints)
@@ -230,21 +230,6 @@ class DistributeTranspiler:
                 outputs={"Out": [orig_param]},
                 attrs={"axis": 0})
 
-        self.lr_param_mapping = self._create_lr_param_mapping()
-
-    def _create_lr_param_mapping(self):
-        lr_mapping = dict()
-        for _, opt_op in enumerate(self.optimize_ops):
-            if not opt_op.inputs or not opt_op.inputs.has_key("LearningRate") \
-                    or not opt_op.inputs.has_key("Param"):
-                continue
-            lr = opt_op.inputs["LearningRate"].name
-            param = opt_op.inputs["Param"].name
-            if not lr_mapping.has_key(lr):
-                lr_mapping.update({lr: list()})
-            lr_mapping[lr].append(param)
-        return lr_mapping
-
     def _create_vars_from_blocklist(self, program, block_list):
         # Create respective variables using the block_list
         block_map = dict()
@@ -271,6 +256,7 @@ class DistributeTranspiler:
                 splited_shape = [rows]
                 if len(orig_shape) >= 2:
                     splited_shape.extend(orig_shape[1:])
+                print("###splited: ", size, rows, splited_shape)
                 var = program.global_block().create_var(
                     name="%s.block%d" % (varname, i),
                     psersistable=False,
@@ -278,6 +264,7 @@ class DistributeTranspiler:
                     type=orig_var.type,
                     shape=splited_shape)  # flattend splited var
                 var_mapping[varname].append(var)
+                print("###created split var ", var)
         return var_mapping
 
     def _clone_var(self, block, var):
@@ -369,18 +356,9 @@ class DistributeTranspiler:
             pass
         return orig_shape
 
-    def _fetch_var_names(self, param_dict):
-        res = []
-        if not param_dict:
-            return res
-        for _, values in param_dict.iteritems():
-            if not isinstance(values, list):
-                values = [values]
-            res += [v.name for v in values]
-        return res
-
     def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
         program = optimize_block.program
+        pserver_block = program.global_block()
         new_inputs = dict()
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
@@ -395,11 +373,11 @@ class DistributeTranspiler:
                 # do not append this op if current endpoint
                 # is not dealing with this grad block
                 return
-            merged_var = program.global_block().vars[grad_block.name]
+            merged_var = pserver_block.vars[grad_block.name]
             # append merging ops if trainers > 1
             if self.trainers > 1:
                 vars2merge = self._create_var_for_trainers(
-                    program.global_block(), grad_block, self.trainers)
+                    pserver_block, grad_block, self.trainers)
                 optimize_block.append_op(
                     type="sum",
                     inputs={"X": vars2merge},
...@@ -419,29 +397,27 @@ class DistributeTranspiler: ...@@ -419,29 +397,27 @@ class DistributeTranspiler:
break break
if not param_block: if not param_block:
return return
tmpvar = program.global_block().create_var( tmpvar = pserver_block.create_var(
name=param_block.name, name=param_block.name,
persistable=True, persistable=True,
dtype=param_block.dtype, dtype=param_block.dtype,
shape=param_block.shape) shape=param_block.shape)
new_inputs[key] = tmpvar new_inputs[key] = tmpvar
elif key == "LearningRate": elif key == "LearningRate":
# learning rate variable has already been created by a non-optimize op, # learning rate variable has already been created by a non-optimize op,
# don't create it again. # don't create it again.
new_inputs[key] = program.global_block().vars[opt_op.input(key)[ new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
0]]
for key in opt_op.input_names: for key in opt_op.input_names:
new_shape = None new_shape = None
if key in ["Param", "Grad", "LearningRate"]: if key in ["Param", "Grad", "LearningRate"]:
continue continue
var = program.global_block().vars[opt_op.input(key)[0]] var = self.program.global_block().vars[opt_op.input(key)[0]]
# update accumulator variable shape # update accumulator variable shape
param_shape = new_inputs["Param"].shape param_shape = new_inputs["Param"].shape
new_shape = self._get_optimizer_input_shape(opt_op.type, key, new_shape = self._get_optimizer_input_shape(opt_op.type, key,
var.shape, param_shape) var.shape, param_shape)
tmpvar = program.global_block().create_var( tmpvar = pserver_block.create_var(
name=var.name, name=var.name,
persistable=var.persistable, persistable=var.persistable,
dtype=var.dtype, dtype=var.dtype,
...@@ -449,11 +425,14 @@ class DistributeTranspiler: ...@@ -449,11 +425,14 @@ class DistributeTranspiler:
new_inputs[key] = tmpvar new_inputs[key] = tmpvar
# change output's ParamOut variable # change output's ParamOut variable
opt_op.outputs["ParamOut"] = new_inputs["Param"] outputs = self._get_output_map_from_op(self.program.global_block().vars,
opt_op)
outputs["ParamOut"] = new_inputs["Param"]
optimize_block.append_op( optimize_block.append_op(
type=opt_op.type, type=opt_op.type,
inputs=new_inputs, inputs=new_inputs,
outputs=opt_op.outputs, outputs=outputs,
attrs=opt_op.attrs) attrs=opt_op.attrs)
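To make the shape handling above concrete (illustrative values only, assuming the usual case where an optimizer accumulator such as Moment mirrors the parameter shape):

# If this pserver owns a 5-row block of a 10x5 parameter, the accumulator
# input is re-created with the block's shape rather than the full shape.
param_shape = (5, 5)          # shape of the split "Param" block on this pserver
full_moment_shape = (10, 5)   # accumulator shape in the trainer-side program
new_shape = param_shape       # expected result of _get_optimizer_input_shape for "Moment"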
def _append_pserver_non_opt_ops(self, optimize_block, opt_op): def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
...@@ -497,11 +476,12 @@ class DistributeTranspiler: ...@@ -497,11 +476,12 @@ class DistributeTranspiler:
# If one op's input is another op's output or # If one op's input is another op's output or
# one op's output is another op's input, we say # one op's output is another op's input, we say
# the two operators are connected. # the two operators are connected.
op1_input_names = self._fetch_var_names(op1.inputs) op1_input_names = op1.desc.input_arg_names()
op1_output_names = self._fetch_var_names(op1.outputs) op1_output_names = op1.desc.output_arg_names()
op2_input_names = op2.desc.input_arg_names()
op2_output_names = op2.desc.output_arg_names()
op2_input_names = self._fetch_var_names(op2.inputs)
op2_output_names = self._fetch_var_names(op2.outputs)
if set(op1_output_names) & set(op2_input_names) or \ if set(op1_output_names) & set(op2_input_names) or \
set(op1_input_names) & set(op2_output_names): set(op1_input_names) & set(op2_output_names):
return True return True
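A standalone sketch of the connectivity test above, with hypothetical variable names:

# Two ops are "connected" when one op's outputs intersect the other's inputs.
op1_output_names = {"fc_0.w_0@GRAD"}
op2_input_names = {"fc_0.w_0@GRAD", "learning_rate_0"}
connected = bool(op1_output_names & op2_input_names)  # True in this example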
...@@ -521,8 +501,8 @@ class DistributeTranspiler: ...@@ -521,8 +501,8 @@ class DistributeTranspiler:
def _is_opt_op(self, op): def _is_opt_op(self, op):
# NOTE: This is a HACK implementation. # NOTE: This is a HACK implementation.
# optimize ops: SGDOptimizer, MomentumOptimizer, AdamOptimizer, etc. # optimize ops: SGDOptimizer, MomentumOptimizer, AdamOptimizer, etc.
if op.inputs and op.inputs.has_key("Param") \ if "Param" in op.input_names and \
and op.inputs.has_key("LearningRate"): "LearningRate" in op.input_names:
return True return True
return False return False
...@@ -530,12 +510,12 @@ class DistributeTranspiler: ...@@ -530,12 +510,12 @@ class DistributeTranspiler:
param_names = [ param_names = [
p.name for p in self.param_grad_ep_mapping[endpoint]["params"] p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
] ]
if op.inputs["Param"].name in param_names: if op.input("Param") in param_names:
return True return True
else: else:
for n in param_names: for n in param_names:
param = op.inputs["Param"].name param = op.input("Param")[0]
if same_or_split_var(n, param) and n != op.inputs["Param"].name: if same_or_split_var(n, param) and n != param:
return True return True
return False return False
return False return False
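For reference, a rough sketch of the check same_or_split_var is assumed to perform, based on the ".block%d" naming used when splitting variables (an assumption, not the actual helper):

# Assumed behaviour: "w.block3" is treated as the same variable as "w".
def same_or_split_var_sketch(p_name, var_name):
    return p_name == var_name or p_name.startswith(var_name + ".block")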
...@@ -551,6 +531,8 @@ class DistributeTranspiler: ...@@ -551,6 +531,8 @@ class DistributeTranspiler:
""" """
# step5 # step5
pserver_program = Program() pserver_program = Program()
print("param mapping on pserver: #### ",
self.param_grad_ep_mapping[endpoint]["params"])
for v in self.param_grad_ep_mapping[endpoint]["params"]: for v in self.param_grad_ep_mapping[endpoint]["params"]:
self._clone_var(pserver_program.global_block(), v) self._clone_var(pserver_program.global_block(), v)
for v in self.param_grad_ep_mapping[endpoint]["grads"]: for v in self.param_grad_ep_mapping[endpoint]["grads"]:
...@@ -564,7 +546,6 @@ class DistributeTranspiler: ...@@ -564,7 +546,6 @@ class DistributeTranspiler:
persistable=True, persistable=True,
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
# step6 # step6
optimize_block = pserver_program.create_block(0) optimize_block = pserver_program.create_block(0)
# step 6.1 # step 6.1
......
...@@ -400,9 +400,6 @@ class Operator(object): ...@@ -400,9 +400,6 @@ class Operator(object):
""" """
self.block = block self.block = block
self.desc = desc self.desc = desc
# for clone a new operator
self.inputs = inputs
self.outputs = outputs
self.attrs = attrs self.attrs = attrs
if len(self.desc.type()) != 0: if len(self.desc.type()) != 0:
return return
......
...@@ -16,8 +16,6 @@ import ops ...@@ -16,8 +16,6 @@ import ops
from ops import * from ops import *
import nn import nn
from nn import * from nn import *
import detection
from detection import *
import io import io
from io import * from io import *
import tensor import tensor
...@@ -33,7 +31,6 @@ from detection import * ...@@ -33,7 +31,6 @@ from detection import *
__all__ = [] __all__ = []
__all__ += math_op_patch.__all__ __all__ += math_op_patch.__all__
__all__ += detection.__all__
__all__ += nn.__all__ __all__ += nn.__all__
__all__ += io.__all__ __all__ += io.__all__
__all__ += tensor.__all__ __all__ += tensor.__all__
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -15,20 +15,32 @@ ...@@ -15,20 +15,32 @@
All layers related to the detection neural network. All layers related to the detection neural network.
""" """
from layer_function_generator import generate_layer_fn
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
from ..framework import Variable
import tensor import tensor
import ops import ops
import nn import nn
import math import math
__all__ = [ __all__ = [
'detection_output',
'prior_box', 'prior_box',
'multi_box_head', 'multi_box_head',
'bipartite_match',
'target_assign',
'detection_output',
'ssd_loss',
]
__auto__ = [
'iou_similarity',
'box_coder',
] ]
__all__ += __auto__
for _OP in set(__auto__):
globals()[_OP] = generate_layer_fn(_OP)
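The loop above exposes each op listed in __auto__ as an ordinary layer function; the new unit test below calls them like this (repeated here for convenience):

import paddle.v2.fluid.layers as layers

x = layers.data(name='x', shape=[4], dtype='float32')
y = layers.data(name='y', shape=[4], dtype='float32')
z = layers.data(name='z', shape=[4], dtype='float32', lod_level=1)
iou = layers.iou_similarity(x=x, y=y)
bcoder = layers.box_coder(
    prior_box=x, prior_box_var=y, target_box=z,
    code_type='encode_center_size')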
def detection_output(scores, def detection_output(scores,
loc, loc,
...@@ -98,18 +110,13 @@ def detection_output(scores, ...@@ -98,18 +110,13 @@ def detection_output(scores,
""" """
helper = LayerHelper("detection_output", **locals()) helper = LayerHelper("detection_output", **locals())
decoded_box = helper.create_tmp_variable(dtype=loc.dtype) decoded_box = box_coder(
helper.append_op( prior_box=prior_box,
type="box_coder", prior_box_var=prior_box_var,
inputs={ target_box=loc,
'PriorBox': prior_box, code_type='decode_center_size')
'PriorBoxVar': prior_box_var,
'TargetBox': loc
},
outputs={'OutputBox': decoded_box},
attrs={'code_type': 'decode_center_size'})
nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
helper.append_op( helper.append_op(
type="multiclass_nms", type="multiclass_nms",
inputs={'Scores': scores, inputs={'Scores': scores,
...@@ -280,22 +287,22 @@ def prior_box(inputs, ...@@ -280,22 +287,22 @@ def prior_box(inputs,
if aspect_ratios: if aspect_ratios:
_is_list_or_tuple_and_equal( _is_list_or_tuple_and_equal(
aspect_ratios, num_layer, aspect_ratios, num_layer,
'aspect_ratios should be list and the length of inputs ' 'aspect_ratios should be list or tuple, and the length of inputs '
'and aspect_ratios should be the same.') 'and aspect_ratios should be the same.')
if step_h: if step_h:
_is_list_or_tuple_and_equal( _is_list_or_tuple_and_equal(
step_h, num_layer, step_h, num_layer,
'step_h should be list and the length of inputs and ' 'step_h should be list or tuple, and the length of inputs and '
'step_h should be the same.') 'step_h should be the same.')
if step_w: if step_w:
_is_list_or_tuple_and_equal( _is_list_or_tuple_and_equal(
step_w, num_layer, step_w, num_layer,
'step_w should be list and the length of inputs and ' 'step_w should be list or tuple, and the length of inputs and '
'step_w should be the same.') 'step_w should be the same.')
if steps: if steps:
_is_list_or_tuple_and_equal( _is_list_or_tuple_and_equal(
steps, num_layer, steps, num_layer,
'steps should be list and the length of inputs and ' 'steps should be list or tuple, and the length of inputs and '
'step_w should be the same.') 'step_w should be the same.')
step_w = steps step_w = steps
step_h = steps step_h = steps
...@@ -339,6 +346,331 @@ def prior_box(inputs, ...@@ -339,6 +346,331 @@ def prior_box(inputs,
return box, var return box, var
def bipartite_match(dist_matrix, name=None):
"""
**Bipartite matching operator**
This operator implements a greedy bipartite matching algorithm, which is used
to obtain the matching with the maximum distance based on the input
distance matrix. For an input 2D matrix, the bipartite matching algorithm can
find the matched column for each row, and can also find the matched row for
each column. This operator only calculates matched indices from column
to row. For each instance, the number of matched indices is the number
of columns of the input distance matrix.
There are two outputs, to save matched indices and distances.
Briefly, this algorithm matches the best (maximum distance) row entity
to each column entity, and the matched indices are not duplicated
in each row of ColToRowMatchIndices. If a column entity is not matched to
any row entity, -1 is set in ColToRowMatchIndices.
Please note that the input DistMat can be LoDTensor (with LoD) or Tensor.
If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size.
If Tensor, the height of ColToRowMatchIndices is 1.
Args:
dist_matrix(Variable): This input is a 2-D LoDTensor with shape
[K, M]. It is the pair-wise distance matrix between the entities
represented by each row and each column. For example, assume one
entity is A with shape [K] and another entity is B with shape [M];
dist_matrix[i][j] is the distance between A[i] and B[j]. The larger
the distance is, the better the match. Please note that
this tensor can contain LoD information to represent a batch of
inputs. One instance of this batch can contain different numbers of
entities.
Returns:
match_indices(Variable): A 2-D Tensor with shape [N, M] in int type.
N is the batch size. If match_indices[i][j] is -1, it
means B[j] does not match any entity in i-th instance.
Otherwise, it means B[j] is matched to row
match_indices[i][j] in i-th instance. The row number of
i-th instance is saved in match_indices[i][j].
match_distance(Variable): A 2-D Tensor with shape [N, M] in float type.
N is batch size. If match_indices[i][j] is -1,
match_distance[i][j] is also -1.0. Otherwise, assume
match_indices[i][j] = d, and call the row offsets of each instance LoD;
then match_distance[i][j] = dist_matrix[d + LoD[i]][j].
"""
helper = LayerHelper('bipartite_match', **locals())
match_indices = helper.create_tmp_variable(dtype='int32')
match_distance = helper.create_tmp_variable(dtype=dist_matrix.dtype)
helper.append_op(
type='bipartite_match',
inputs={'DistMat': dist_matrix},
outputs={
'ColToRowMatchIndices': match_indices,
'ColToRowMatchDist': match_distance
})
return match_indices, match_distance
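A minimal usage sketch, following the way ssd_loss below wires it up (the IoU matrix between ground-truth and prior boxes serves as the distance matrix; the data names and shapes are illustrative):

import paddle.v2.fluid.layers as layers

gt_box = layers.data(
    name='gt_box', shape=[4], lod_level=1, dtype='float32')
prior = layers.data(
    name='prior', shape=[10, 4], append_batch_size=False, dtype='float32')
iou = layers.iou_similarity(x=gt_box, y=prior)
match_indices, match_dist = layers.bipartite_match(iou)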
def target_assign(input,
matched_indices,
negative_indices=None,
mismatch_value=None,
name=None):
"""
**Target assigner operator**
Given the target bounding boxes or labels, this operator assigns
classification and regression targets to each prediction, as well as
weights for each prediction. The weights are used to specify which
predictions should not contribute to the training loss.
For each instance, the outputs `out` and `out_weight` are assigned based on
`match_indices` and `negative_indices`.
Assumed that the row offset for each instance in `input` is called lod,
this operator assigns classification/regression targets by performing the
following steps:
1. Assigning all outputs based on `match_indices`:
If id = match_indices[i][j] > 0,
out[i][j][0 : K] = X[lod[i] + id][j % P][0 : K]
out_weight[i][j] = 1.
Otherwise,
out[i][j][0 : K] = {mismatch_value, mismatch_value, ...}
out_weight[i][j] = 0.
2. Assigning out_weight based on `neg_indices` if `neg_indices` is provided:
Assumed that the row offset for each instance in `neg_indices` is called neg_lod,
for i-th instance and each `id` of neg_indices in this instance:
out[i][id][0 : K] = {mismatch_value, mismatch_value, ...}
out_weight[i][id] = 1.0
Args:
input (Variable): This input is a 3D LoDTensor with shape [M, P, K].
matched_indices (Variable): The input matched indices are a
2D Tensor<int32> with shape [N, P]. If MatchIndices[i][j] is -1,
the j-th column entity is not matched to any row entity in the
i-th instance.
negative_indices (Variable): The input negative example indices are
an optional input with shape [Neg, 1] and int32 type, where Neg is
the total number of negative example indices.
mismatch_value (float32): Fill this value to the mismatched location.
Returns:
out (Variable): The output is a 3D Tensor with shape [N, P, K],
where N and P are the same as they are in `matched_indices`, and K is
the same as it is in the input X.
out_weight (Variable): The weight for output with the shape of [N, P, 1].
"""
helper = LayerHelper('target_assign', **locals())
out = helper.create_tmp_variable(dtype=input.dtype)
out_weight = helper.create_tmp_variable(dtype='float32')
helper.append_op(
type='target_assign',
inputs={
'X': input,
'MatchIndices': matched_indices,
'NegIndices': negative_indices
},
outputs={'Out': out,
'OutWeight': out_weight},
attrs={'mismatch_value': mismatch_value})
return out, out_weight
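Usage mirrors the new test_detection_api case below (a sketch; match_indices is the output of bipartite_match as in the example above):

import paddle.v2.fluid.layers as layers

gt = layers.data(name='gt', shape=[1, 1], dtype='int32', lod_level=1)
# match_indices is produced by layers.bipartite_match(...)
trg, trg_weight = layers.target_assign(
    gt, match_indices, mismatch_value=0)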
def ssd_loss(location,
confidence,
gt_box,
gt_label,
prior_box,
prior_box_var=None,
background_label=0,
overlap_threshold=0.5,
neg_pos_ratio=3.0,
neg_overlap=0.5,
loc_loss_weight=1.0,
conf_loss_weight=1.0,
match_type='per_prediction',
mining_type='max_negative',
sample_size=None):
"""
**Multi-box loss layer for the SSD object detection algorithm**
This layer computes the detection loss for SSD, given the location offset
predictions, confidence predictions, prior boxes, ground-truth bounding
boxes and labels, and the type of hard example mining. The returned loss
is a weighted sum of the localization loss (or regression loss) and the
confidence loss (or classification loss), computed by performing the following steps:
1. Find matched bounding boxes by the bipartite matching algorithm.
1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
1.2 Compute matched bounding boxes by the bipartite matching algorithm.
2. Compute confidence for mining hard examples
2.1. Get the target label based on matched indices.
2.2. Compute confidence loss.
3. Apply hard example mining to get the negative example indices and update
the matched indices.
4. Assign classification and regression targets
4.1. Encoded bbox according to the prior boxes.
4.2. Assign regression targets.
4.3. Assign classification targets.
5. Compute the overall objective loss.
5.1 Compute confidence loss.
5.2 Compute localization loss.
5.3 Compute the overall weighted loss.
Args:
location (Variable): The location predictions are a 3D Tensor with
shape [N, Np, 4], N is the batch size, Np is total number of
predictions for each instance. 4 is the number of coordinate values,
the layout is [xmin, ymin, xmax, ymax].
confidence (Variable): The confidence predictions are a 3D Tensor
with shape [N, Np, C], N and Np are the same as they are in
`location`, C is the class number.
gt_box (Variable): The ground-truth bounding boxes (bboxes) are a 2D
LoDTensor with shape [Ng, 4], Ng is the total number of ground-truth
bboxes of mini-batch input.
gt_label (Variable): The ground-truth labels are a 2D LoDTensor
with shape [Ng, 1].
prior_box (Variable): The prior boxes are a 2D Tensor with shape [Np, 4].
prior_box_var (Variable): The variance of prior boxes are a 2D Tensor
with shape [Np, 4].
background_label (int): The index of background label, 0 by default.
overlap_threshold (float): If match_type is 'per_prediction', use
`overlap_threshold` to determine the extra matching bboxes when
finding matched boxes. 0.5 by default.
neg_pos_ratio (float): The ratio of the negative boxes to the positive
boxes, used only when mining_type is max_negative, 3.0 by default.
neg_overlap (float): The negative overlap upper bound for the unmatched
predictions. Use only when mining_type is max_negative,
0.5 by default.
sample_size (int): The max sample size of negative box, used only when
mining_type is hard_example.
loc_loss_weight (float): Weight for localization loss, 1.0 by default.
conf_loss_weight (float): Weight for confidence loss, 1.0 by default.
match_type (str): The type of matching method during training, should
be 'bipartite' or 'per_prediction'.
mining_type (str): The hard example mining type, should be 'hard_example'
or 'max_negative', now only support `max_negative`.
Returns:
Variable: The weighted sum of the localization loss and confidence loss,
with shape [N * Np, 1], N and Np are the same as they are
in `location`.
Raises:
ValueError: If mining_type is 'hard_example'; currently only the
mining type `max_negative` is supported.
Examples:
.. code-block:: python
pb = layers.data(
name='prior_box',
shape=[10, 4],
append_batch_size=False,
dtype='float32')
pbv = layers.data(
name='prior_box_var',
shape=[10, 4],
append_batch_size=False,
dtype='float32')
loc = layers.data(name='target_box', shape=[10, 4], dtype='float32')
scores = layers.data(name='scores', shape=[10, 21], dtype='float32')
gt_box = layers.data(
name='gt_box', shape=[4], lod_level=1, dtype='float32')
gt_label = layers.data(
name='gt_label', shape=[1], lod_level=1, dtype='float32')
loss = layers.ssd_loss(loc, scores, gt_box, gt_label, pb, pbv)
"""
helper = LayerHelper('ssd_loss', **locals())
if mining_type != 'max_negative':
raise ValueError("Only support mining_type == max_negative now.")
num, num_prior, num_class = confidence.shape
def __reshape_to_2d(var):
return ops.reshape(x=var, shape=[-1, var.shape[-1]])
# 1. Find matched bounding box by prior box.
# 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
iou = iou_similarity(x=gt_box, y=prior_box)
# 1.2 Compute matched bounding box by bipartite matching algorithm.
matched_indices, matched_dist = bipartite_match(iou)
# 2. Compute confidence for mining hard examples
# 2.1. Get the target label based on matched indices
gt_label = ops.reshape(x=gt_label, shape=gt_label.shape + (1, ))
target_label, _ = target_assign(
gt_label, matched_indices, mismatch_value=background_label)
# 2.2. Compute confidence loss.
# Reshape confidence to 2D tensor.
confidence = __reshape_to_2d(confidence)
target_label = tensor.cast(x=target_label, dtype='int64')
target_label = __reshape_to_2d(target_label)
conf_loss = nn.softmax_with_cross_entropy(confidence, target_label)
# 3. Mining hard examples
conf_loss = ops.reshape(x=conf_loss, shape=(num, num_prior))
neg_indices = helper.create_tmp_variable(dtype='int32')
dtype = matched_indices.dtype
updated_matched_indices = helper.create_tmp_variable(dtype=dtype)
helper.append_op(
type='mine_hard_examples',
inputs={
'ClsLoss': conf_loss,
'LocLoss': None,
'MatchIndices': matched_indices,
'MatchDist': matched_dist,
},
outputs={
'NegIndices': neg_indices,
'UpdatedMatchIndices': updated_matched_indices
},
attrs={
'neg_pos_ratio': neg_pos_ratio,
'neg_dist_threshold': neg_overlap,  # negative overlap upper bound for unmatched predictions
'mining_type': mining_type,
'sample_size': sample_size,
})
# 4. Assign classification and regression targets
# 4.1. Encoded bbox according to the prior boxes.
encoded_bbox = box_coder(
prior_box=prior_box,
prior_box_var=prior_box_var,
target_box=gt_box,
code_type='encode_center_size')
# 4.2. Assign regression targets
target_bbox, target_loc_weight = target_assign(
encoded_bbox, updated_matched_indices, mismatch_value=background_label)
# 4.3. Assign classification targets
target_label, target_conf_weight = target_assign(
gt_label,
updated_matched_indices,
negative_indices=neg_indices,
mismatch_value=background_label)
# 5. Compute loss.
# 5.1 Compute confidence loss.
target_label = __reshape_to_2d(target_label)
target_label = tensor.cast(x=target_label, dtype='int64')
conf_loss = nn.softmax_with_cross_entropy(confidence, target_label)
target_conf_weight = __reshape_to_2d(target_conf_weight)
conf_loss = conf_loss * target_conf_weight
# 5.2 Compute regression loss.
location = __reshape_to_2d(location)
target_bbox = __reshape_to_2d(target_bbox)
loc_loss = nn.smooth_l1(location, target_bbox)
target_loc_weight = __reshape_to_2d(target_loc_weight)
loc_loss = loc_loss * target_loc_weight
# 5.3 Compute overall weighted loss.
loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
return loss
def multi_box_head(inputs, def multi_box_head(inputs,
num_classes, num_classes,
min_sizes=None, min_sizes=None,
......
...@@ -15,12 +15,11 @@ ...@@ -15,12 +15,11 @@
from __future__ import print_function from __future__ import print_function
import paddle.v2.fluid as fluid import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.layers.detection as detection
from paddle.v2.fluid.framework import Program, program_guard from paddle.v2.fluid.framework import Program, program_guard
import unittest import unittest
class TestBook(unittest.TestCase): class TestDetection(unittest.TestCase):
def test_detection_output(self): def test_detection_output(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
...@@ -47,7 +46,67 @@ class TestBook(unittest.TestCase): ...@@ -47,7 +46,67 @@ class TestBook(unittest.TestCase):
out = layers.detection_output( out = layers.detection_output(
scores=scores, loc=loc, prior_box=pb, prior_box_var=pbv) scores=scores, loc=loc, prior_box=pb, prior_box_var=pbv)
self.assertIsNotNone(out) self.assertIsNotNone(out)
# print(str(program)) self.assertEqual(out.shape[-1], 6)
print(str(program))
def test_detection_api(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[4], dtype='float32')
y = layers.data(name='y', shape=[4], dtype='float32')
z = layers.data(name='z', shape=[4], dtype='float32', lod_level=1)
iou = layers.iou_similarity(x=x, y=y)
bcoder = layers.box_coder(
prior_box=x,
prior_box_var=y,
target_box=z,
code_type='encode_center_size')
self.assertIsNotNone(iou)
self.assertIsNotNone(bcoder)
matched_indices, matched_dist = layers.bipartite_match(iou)
self.assertIsNotNone(matched_indices)
self.assertIsNotNone(matched_dist)
gt = layers.data(
name='gt', shape=[1, 1], dtype='int32', lod_level=1)
trg, trg_weight = layers.target_assign(
gt, matched_indices, mismatch_value=0)
self.assertIsNotNone(trg)
self.assertIsNotNone(trg_weight)
gt2 = layers.data(
name='gt2', shape=[10, 4], dtype='float32', lod_level=1)
trg, trg_weight = layers.target_assign(
gt2, matched_indices, mismatch_value=0)
self.assertIsNotNone(trg)
self.assertIsNotNone(trg_weight)
print(str(program))
def test_ssd_loss(self):
program = Program()
with program_guard(program):
pb = layers.data(
name='prior_box',
shape=[10, 4],
append_batch_size=False,
dtype='float32')
pbv = layers.data(
name='prior_box_var',
shape=[10, 4],
append_batch_size=False,
dtype='float32')
loc = layers.data(name='target_box', shape=[10, 4], dtype='float32')
scores = layers.data(name='scores', shape=[10, 21], dtype='float32')
gt_box = layers.data(
name='gt_box', shape=[4], lod_level=1, dtype='float32')
gt_label = layers.data(
name='gt_label', shape=[1], lod_level=1, dtype='int32')
loss = layers.ssd_loss(loc, scores, gt_box, gt_label, pb, pbv)
self.assertIsNotNone(loss)
self.assertEqual(loss.shape[-1], 1)
print(str(program))
class TestPriorBox(unittest.TestCase): class TestPriorBox(unittest.TestCase):
...@@ -68,7 +127,7 @@ class TestPriorBox(unittest.TestCase): ...@@ -68,7 +127,7 @@ class TestPriorBox(unittest.TestCase):
conv4 = fluid.layers.conv2d(conv3, 3, 3, 2) conv4 = fluid.layers.conv2d(conv3, 3, 3, 2)
conv5 = fluid.layers.conv2d(conv4, 3, 3, 2) conv5 = fluid.layers.conv2d(conv4, 3, 3, 2)
box, var = detection.prior_box( box, var = layers.prior_box(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5], inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
image=images, image=images,
min_ratio=20, min_ratio=20,
......