提交 9806e7f2 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #3522 from reyoung/feature/clone_op

Feature/clone op
......@@ -144,8 +144,18 @@ class OpKernelRegistrar : public Registrar {
grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
static ::paddle::framework::OpRegistrar<op_class, op_maker_class, \
grad_op_class> \
class _OpClass_##op_type##_ : public op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
}; \
class _OpGradClass_##op_type##_ : public grad_op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \
DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \
}; \
static ::paddle::framework::OpRegistrar< \
_OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
......@@ -176,7 +186,8 @@ class OpKernelRegistrar : public Registrar {
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/**
* Macro to mark what Operator and Kernel we will use and tell the compiler to
* Macro to mark what Operator and Kernel
* we will use and tell the compiler to
* link them into target.
*/
#define USE_OP_ITSELF(op_type) \
......@@ -196,7 +207,8 @@ class OpKernelRegistrar : public Registrar {
__attribute__((unused)) = \
TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE()
// TODO(fengjiayi): The following macros seems ugly, do we have better method?
// TODO(fengjiayi): The following macros
// seems ugly, do we have better method?
#ifdef PADDLE_ONLY_CPU
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
......
......@@ -67,10 +67,6 @@ class OperatorBase {
OperatorBase(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs);
OperatorBase(const OperatorBase& o) = delete;
OperatorBase& operator=(const OperatorBase& o) = delete;
OperatorBase(OperatorBase&& o) = delete;
virtual ~OperatorBase() {}
template <typename T>
......@@ -116,10 +112,14 @@ class OperatorBase {
void SetType(const std::string& type) { type_ = type; }
const AttributeMap& Attrs() const { return attrs_; }
// Return a new operator instance, which is as same as this.
// Use unique_ptr to prevent caller forget to delete this pointer.
virtual std::unique_ptr<OperatorBase> Clone() const = 0;
protected:
std::string type_;
// NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs)
// I (Inputs)opear
// O (Outputs)
// OG (Output Gradients)
VarNameMap inputs_;
......@@ -130,12 +130,32 @@ class OperatorBase {
AttributeMap attrs_;
};
// Macro for define a clone method.
// If you are writing an kernel operator, `Clone` will be defined when you
// register it. i.e. `Clone` method is not needed to define by yourself.
#define DEFINE_OP_CLONE_METHOD(CLS) \
std::unique_ptr<OperatorBase> Clone() const final { \
return std::unique_ptr<OperatorBase>(new CLS(*this)); \
}
// Macro for define a default constructor for Operator.
// You can also use
// using PARENT_CLASS::PARENT_CLASS;
// to use parent's constructor.
#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \
CLS(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \
: PARENT_CLS(type, inputs, outputs, attrs) {}
class NOP : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
std::unique_ptr<OperatorBase> Clone() const override {
return std::unique_ptr<OperatorBase>(new NOP(*this));
}
};
// this class not only make proto but also init attribute checkers.
......
......@@ -245,3 +245,21 @@ TEST(OpKernel, multi_inputs) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);
}
class OperatorClone : public paddle::framework::OperatorBase {
public:
DEFINE_OP_CLONE_METHOD(OperatorClone);
OperatorClone(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs,
const paddle::framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void InferShape(const paddle::framework::Scope& scope) const override {}
void Run(const paddle::framework::Scope& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {}
};
TEST(Operator, Clone) {
OperatorClone a("ABC", {}, {}, {});
auto b = a.Clone();
ASSERT_EQ(a.Type(), b->Type());
}
\ No newline at end of file
......@@ -85,7 +85,14 @@ NetOp::NetOp(const std::string& type,
const framework::OperatorBase::VarNameMap& inputs,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
: framework::OperatorBase(type, inputs, outputs, attrs) {}
std::unique_ptr<framework::OperatorBase> NetOp::Clone() const {
PADDLE_ENFORCE(
add_op_done_,
"Must clone a sealed NetOp, invoke Net::CompleteAddOp before clone");
return std::unique_ptr<OperatorBase>(new NetOp(*this));
}
} // namespace operators
} // namespace paddle
......@@ -41,6 +41,18 @@ class NetOp : public framework::OperatorBase {
NetOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs);
NetOp(const NetOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
this->ops_.reserve(o.ops_.size());
std::transform(o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_),
[](const std::shared_ptr<OperatorBase>& op)
-> std::shared_ptr<OperatorBase> {
return std::shared_ptr<OperatorBase>(op->Clone());
});
this->CompleteAddOp();
}
/**
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
......@@ -98,6 +110,8 @@ class NetOp : public framework::OperatorBase {
bool IsNetOp() const override;
std::vector<std::string> OutputVars(bool has_intermediate) const override;
std::unique_ptr<framework::OperatorBase> Clone() const override;
std::vector<std::shared_ptr<OperatorBase>> ops_;
private:
......
......@@ -13,6 +13,7 @@ static int run_cnt = 0;
class TestOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(TestOp);
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
......@@ -70,5 +71,21 @@ TEST(NetOp, insert_op) {
ASSERT_EQ(3UL, net.ops_.size());
}
TEST(NetOp, Clone) {
NetOp net;
net.AddOp(
std::shared_ptr<framework::NOP>(new framework::NOP{"empty", {}, {}, {}}));
net.AddOp(std::shared_ptr<framework::NOP>(
new framework::NOP{"empty2", {}, {}, {}}));
net.CompleteAddOp(true);
auto new_net_op = net.Clone();
ASSERT_NE(new_net_op, nullptr);
ASSERT_TRUE(new_net_op->IsNetOp());
auto* new_net = static_cast<NetOp*>(new_net_op.get());
ASSERT_EQ(2, new_net->ops_.size());
ASSERT_EQ(new_net->ops_[0]->Type(), "empty");
ASSERT_EQ(new_net->ops_[1]->Type(), "empty2");
}
} // namespace operators
} // namespace paddle
......@@ -110,13 +110,20 @@ class RecurrentGradientAlgorithm {
std::shared_ptr<NetOp>* stepnet_;
};
class RecurrentOp final : public framework::OperatorBase {
class RecurrentOp : public framework::OperatorBase {
public:
RecurrentOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs);
RecurrentOp(const RecurrentOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement copy ctor well.
PADDLE_THROW("Not implemented");
}
/**
* InferShape must be called before Run.
*/
* InferShape must be called before Run.
*/
void InferShape(const framework::Scope& scope) const override {
alg_.InferShape(scope);
}
......@@ -137,12 +144,19 @@ class RecurrentOp final : public framework::OperatorBase {
std::shared_ptr<NetOp> stepnet_;
};
class RecurrentGradientOp final : public framework::OperatorBase {
class RecurrentGradientOp : public framework::OperatorBase {
public:
RecurrentGradientOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs,
const framework::AttributeMap& attrs);
RecurrentGradientOp(const RecurrentGradientOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement Copy ctor.
PADDLE_THROW("Not Implemented");
}
/**
* InferShape must be called before Run.
*/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册