提交 cdd1a8f5 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #3436 from Canpio/refactor_registry_macro

Merge maps in OpRegistry and simplify register macros
...@@ -155,19 +155,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker { ...@@ -155,19 +155,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework; namespace f = paddle::framework;
namespace ops = paddle::operators; namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet; using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker); REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker, rowwise_add_grad,
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp); f::EmptyOp);
REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker); REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker, mul_grad, f::EmptyOp);
REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp); REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker, sigmoid_grad, f::EmptyOp);
REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker); REGISTER_OP_WITHOUT_GRADIENT(nograd, f::EmptyOp, f::NoGradOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp); REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker); REGISTER_OP(add, f::EmptyOp, f::AddOpMaker, add_grad, f::EmptyOp);
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker); REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker,
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp); many_output_op_grad, f::EmptyOp);
REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
TEST(Backward, simple_op_grad) { TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp( auto fwd = f::OpRegistry::CreateOp(
......
...@@ -20,16 +20,14 @@ namespace paddle { ...@@ -20,16 +20,14 @@ namespace paddle {
namespace framework { namespace framework {
enum class OpArgType { IN, OUT }; enum class OpArgType { IN, OUT };
static void TransOpArg(const OperatorBase* src_op, static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
OperatorBase::VarNameMap* vars, bool is_grad, OperatorBase::VarNameMap* vars) {
const OpArgType& src_type, bool is_grad) {
const auto& src_inout = const auto& src_inout =
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
auto& dst_inout = *vars; auto& dst_inout = *vars;
const OpProto* proto = OpRegistry::op_info_map().at(src_op->type_).proto_;
const OpProto& proto = OpProtos().at(src_op->type_);
const auto& src_arg_list = const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
for (const auto& arg : src_arg_list) { for (const auto& arg : src_arg_list) {
if (arg.no_gradient() && !is_grad) continue; if (arg.no_gradient() && !is_grad) continue;
const std::string src_name = arg.name(); const std::string src_name = arg.name();
...@@ -43,22 +41,26 @@ static void TransOpArg(const OperatorBase* src_op, ...@@ -43,22 +41,26 @@ static void TransOpArg(const OperatorBase* src_op,
} }
OperatorBase* BuildGradOp(const OperatorBase* op) { OperatorBase* BuildGradOp(const OperatorBase* op) {
auto gop_type_it = OpRegistry::grad_ops().find(op->type_); auto it = OpRegistry::op_info_map().find(op->type_);
PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(), PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"Operator %s do not register gradient type", op->type_); "'%s' has not been registered.", op->type_);
auto& grad_op_type = gop_type_it->second; PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.",
op->type_);
std::string grad_op_type = it->second.grad_op_type_;
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
op->type_);
OperatorBase::VarNameMap inputs; OperatorBase::VarNameMap inputs;
OperatorBase::VarNameMap outputs; OperatorBase::VarNameMap outputs;
TransOpArg(op, &inputs, OpArgType::IN, false); // I TransOpArg(op, OpArgType::IN, false, &inputs); // I
TransOpArg(op, &inputs, OpArgType::OUT, false); // O TransOpArg(op, OpArgType::OUT, false, &inputs); // O
TransOpArg(op, &inputs, OpArgType::OUT, true); // OG TransOpArg(op, OpArgType::OUT, true, &inputs); // OG
TransOpArg(op, &outputs, OpArgType::IN, true); // IG TransOpArg(op, OpArgType::IN, true, &outputs); // IG
auto gop_it = OpRegistry::op_creators().find(grad_op_type);
PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(),
"Operator %s 's Gradient %s's creator cannot be found",
op->type_, grad_op_type);
return gop_it->second(grad_op_type, inputs, outputs, op->attrs_); it = OpRegistry::op_info_map().find(grad_op_type);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", grad_op_type);
return it->second.creator_(grad_op_type, inputs, outputs, op->attrs_);
} }
} // namespace framework } // namespace framework
......
...@@ -8,14 +8,6 @@ USE_OP(add_two); ...@@ -8,14 +8,6 @@ USE_OP(add_two);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class NOP : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
};
class MutiInOutOpMaker : public OpProtoAndCheckerMaker { class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
public: public:
MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
...@@ -62,10 +54,8 @@ TEST(GradOpBuilder, AddTwo) { ...@@ -62,10 +54,8 @@ TEST(GradOpBuilder, AddTwo) {
EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y")); EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
} }
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker); REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_GRADIENT_OP(mult_io, mult_io_grad, f::NOP); REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker);
REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP);
TEST(GradOpBuilder, MutiInOut) { TEST(GradOpBuilder, MutiInOut) {
std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp( std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
#include <type_traits> #include <type_traits>
#include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
...@@ -119,6 +120,12 @@ class OpProtoAndCheckerMaker { ...@@ -119,6 +120,12 @@ class OpProtoAndCheckerMaker {
bool validated_{false}; bool validated_{false};
}; };
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
class OpRegistry { class OpRegistry {
using VarNameMap = OperatorBase::VarNameMap; using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*( using OpCreator = std::function<OperatorBase*(
...@@ -126,46 +133,56 @@ class OpRegistry { ...@@ -126,46 +133,56 @@ class OpRegistry {
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>; const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
public: public:
template <typename OpType, typename ProtoMakerType> struct OpInfo {
static void RegisterOp(const std::string& op_type) { OpCreator creator_;
op_creators()[op_type] = []( std::string grad_op_type_;
const std::string& type, const VarNameMap& inputs, OpProto* proto_;
const VarNameMap& outputs, const AttributeMap& attrs) { OpAttrChecker* checker_;
return new OpType(type, inputs, outputs, attrs); };
};
OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = OpProtos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker);
maker.Validate();
op_proto.set_type(op_type);
PADDLE_ENFORCE(
op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString());
}
template <typename GradOpType> template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterGradOp(const std::string& op_type, static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) { const std::string& grad_op_type) {
op_creators()[grad_op_type] = []( PADDLE_ENFORCE(op_info_map().count(op_type) == 0,
const std::string& type, const VarNameMap& inputs, "'%s' is registered more than once.", op_type);
const VarNameMap& outputs, const AttributeMap& attrs) { OpInfo op_info;
return new GradOpType(type, inputs, outputs, attrs); op_info.creator_ = [](const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs,
const AttributeMap& attrs) {
return new OpType(type, inputs, outputs, attrs);
}; };
grad_ops()[op_type] = grad_op_type; op_info.grad_op_type_ = grad_op_type;
if (std::type_index(typeid(ProtoMakerType)) !=
std::type_index(typeid(NOPMaker))) {
op_info.proto_ = new OpProto;
op_info.checker_ = new OpAttrChecker;
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
maker.Validate();
op_info.proto_->set_type(op_type);
PADDLE_ENFORCE(
op_info.proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_info.proto_->InitializationErrorString());
} else {
op_info.proto_ = nullptr;
op_info.checker_ = nullptr;
}
op_info_map().insert(std::make_pair(op_type, op_info));
// register gradient op
if (!grad_op_type.empty()) {
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
}
} }
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type, static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs, const VarNameMap& inputs,
const VarNameMap& outputs, const VarNameMap& outputs,
AttributeMap attrs) { AttributeMap attrs) {
auto op_create_it = op_creators().find(type); auto it = op_info_map().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(), PADDLE_ENFORCE(it != op_info_map().end(),
"Operator %s cannot be found.", type); "Operator '%s' has not been registered.", type);
op_checkers().at(type).Check(attrs); it->second.checker_->Check(attrs);
auto op = it->second.creator_(type, inputs, outputs, attrs);
auto op = op_create_it->second(type, inputs, outputs, attrs);
return std::shared_ptr<OperatorBase>(op); return std::shared_ptr<OperatorBase>(op);
} }
...@@ -200,49 +217,32 @@ class OpRegistry { ...@@ -200,49 +217,32 @@ class OpRegistry {
return grad_op; return grad_op;
} }
static std::unordered_map<std::string, std::string>& grad_ops() { static std::unordered_map<std::string, const OpInfo>& op_info_map() {
static std::unordered_map<std::string, std::string> grad_ops_; static std::unordered_map<std::string, const OpInfo> op_info_map_;
return grad_ops_; return op_info_map_;
}
static std::unordered_map<std::string, OpCreator>& op_creators() {
static std::unordered_map<std::string, OpCreator> op_creators_;
return op_creators_;
}
private:
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
} }
}; };
class Registrar { class Registrar {
public: public:
// In our design, various kinds of classes, e.g., operators and kernels, have // In our design, various kinds of classes, e.g., operators and kernels,
// their corresponding registry and registrar. The action of registration is // have their corresponding registry and registrar. The action of
// in the constructor of a global registrar variable, which, however, are not // registration is in the constructor of a global registrar variable, which,
// used in the code that calls package framework, and would be removed from // however, are not used in the code that calls package framework, and would
// the generated binary file by the linker. To avoid such removal, we add // be removed from the generated binary file by the linker. To avoid such
// Touch to all registrar classes and make USE_OP macros to call this // removal, we add Touch to all registrar classes and make USE_OP macros to
// method. So, as long as the callee code calls USE_OP, the global // call this method. So, as long as the callee code calls USE_OP, the global
// registrar variable won't be removed by the linker. // registrar variable won't be removed by the linker.
void Touch() {} void Touch() {}
}; };
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType, typename GradOpType>
class OpRegistrar : public Registrar { class OpRegistrar : public Registrar {
public: public:
explicit OpRegistrar(const char* op_type) { explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type); OpRegistrar(const char* op_type, const char* grad_op_type) {
} OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
}; grad_op_type);
template <typename GradOpType>
class GradOpRegistrar : public Registrar {
public:
GradOpRegistrar(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
} }
}; };
...@@ -268,30 +268,20 @@ class OpKernelRegistrar : public Registrar { ...@@ -268,30 +268,20 @@ class OpKernelRegistrar : public Registrar {
/** /**
* Macro to register Operator. * Macro to register Operator.
*/ */
#define REGISTER_OP(op_type, op_class, op_maker_class) \ #define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \ static ::paddle::framework::OpRegistrar<op_class, op_maker_class, \
__op_registrar_##op_type##__(#op_type); \ grad_op_class> \
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
int TouchOpRegistrar_##op_type() { \ int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \ __op_registrar_##op_type##__.Touch(); \
return 0; \ return 0; \
} }
/** #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
* Macro to register Gradient Operator. REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
*/
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##grad_op_type, \
"REGISTER_GRADIENT_OP must be called in global namespace"); \
static ::paddle::framework::GradOpRegistrar<grad_op_class> \
__op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
#grad_op_type); \
int TouchOpGradientRegistrar_##op_type() { \
__op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \
return 0; \
}
/** /**
* Macro to register OperatorKernel. * Macro to register OperatorKernel.
...@@ -307,14 +297,6 @@ class OpKernelRegistrar : public Registrar { ...@@ -307,14 +297,6 @@ class OpKernelRegistrar : public Registrar {
return 0; \ return 0; \
} }
/**
* Macro to Forbid user register Gradient Operator.
*/
#define NO_GRADIENT(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##op_type##_grad, \
"NO_GRADIENT must be called in global namespace")
#define REGISTER_OP_GPU_KERNEL(op_type, ...) \ #define REGISTER_OP_GPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
...@@ -333,23 +315,6 @@ class OpKernelRegistrar : public Registrar { ...@@ -333,23 +315,6 @@ class OpKernelRegistrar : public Registrar {
static int use_op_itself_##op_type##_ __attribute__((unused)) = \ static int use_op_itself_##op_type##_ __attribute__((unused)) = \
TouchOpRegistrar_##op_type() TouchOpRegistrar_##op_type()
// TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use
// `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code can't
// be compiled. `NO_GRAD` should be removed after all gradient ops are
// compeleted.
#define NO_GRAD
#ifndef NO_GRAD
#define USE_OP_GRADIENT(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_gradient_##op_type, \
"USE_OP_GRADIENT must be called in global namespace"); \
extern int TouchOpGradientRegistrar_##op_type(); \
static int use_op_gradient_##op_type##_ __attribute__((unused)) = \
TouchOpGradientRegistrar_##op_type()
#else
#define USE_OP_GRADIENT(op_type)
#endif
#define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \ #define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \ __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
...@@ -369,18 +334,13 @@ class OpKernelRegistrar : public Registrar { ...@@ -369,18 +334,13 @@ class OpKernelRegistrar : public Registrar {
USE_OP_DEVICE_KERNEL(op_type, GPU) USE_OP_DEVICE_KERNEL(op_type, GPU)
#endif #endif
#define USE_NO_GRAD_OP(op_type) \ #define USE_CPU_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU);
#define USE_CPU_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU); \
USE_OP_GRADIENT(op_type)
#define USE_OP(op_type) \ #define USE_OP(op_type) \
USE_NO_GRAD_OP(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_GRADIENT(op_type) USE_OP_KERNEL(op_type)
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -59,11 +59,10 @@ static void BuildVar(const std::string& param_name, ...@@ -59,11 +59,10 @@ static void BuildVar(const std::string& param_name,
var->add_arguments(arg_name); var->add_arguments(arg_name);
} }
} }
REGISTER_OP_WITHOUT_GRADIENT(cos_sim, paddle::framework::CosineOp,
REGISTER_OP(cos_sim, paddle::framework::CosineOp, paddle::framework::CosineOpProtoAndCheckerMaker);
paddle::framework::CosineOpProtoAndCheckerMaker); REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp,
REGISTER_OP(my_test_op, paddle::framework::MyTestOp, paddle::framework::MyTestOpProtoAndCheckerMaker);
paddle::framework::MyTestOpProtoAndCheckerMaker);
TEST(OpRegistry, CreateOp) { TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
......
...@@ -33,14 +33,6 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -33,14 +33,6 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
} }
#endif #endif
static std::unordered_map<std::string, OpProto>* g_op_protos = nullptr;
std::unordered_map<std::string, OpProto>& OpProtos() {
if (g_op_protos == nullptr) {
g_op_protos = new std::unordered_map<std::string, OpProto>();
}
return *g_op_protos;
}
const std::string& OperatorBase::Input(const std::string& name) const { const std::string& OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name); auto& ins = Inputs(name);
PADDLE_ENFORCE_EQ(ins.size(), 1UL, PADDLE_ENFORCE_EQ(ins.size(), 1UL,
...@@ -149,14 +141,18 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const { ...@@ -149,14 +141,18 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
} }
return ret_val; return ret_val;
} }
auto it = OpProtos().find(type_); auto it = OpRegistry::op_info_map().find(type_);
PADDLE_ENFORCE( PADDLE_ENFORCE(
it != OpProtos().end(), it != OpRegistry::op_info_map().end(),
"Operator %s not registered, cannot figure out intermediate outputs", "Operator %s not registered, cannot figure out intermediate outputs",
type_); type_);
PADDLE_ENFORCE(
it->second.proto_ != nullptr,
"Operator %s has no OpProto, cannot figure out intermediate outputs",
type_);
// get all OpProto::Var for outputs // get all OpProto::Var for outputs
for (auto& o : it->second.outputs()) { for (auto& o : it->second.proto_->outputs()) {
// ignore all intermediate output // ignore all intermediate output
if (o.intermediate()) continue; if (o.intermediate()) continue;
auto out = outputs_.find(o.name()); auto out = outputs_.find(o.name());
......
...@@ -50,8 +50,6 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -50,8 +50,6 @@ inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix; return var_name + kGradVarSuffix;
} }
extern std::unordered_map<std::string, OpProto>& OpProtos();
class OperatorBase; class OperatorBase;
class InferShapeContext; class InferShapeContext;
class ExecutionContext; class ExecutionContext;
...@@ -129,6 +127,14 @@ class OperatorBase { ...@@ -129,6 +127,14 @@ class OperatorBase {
AttributeMap attrs_; AttributeMap attrs_;
}; };
class NOP : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
};
class InferShapeContext { class InferShapeContext {
public: public:
InferShapeContext(const OperatorBase& op, const Scope& scope) InferShapeContext(const OperatorBase& op, const Scope& scope)
...@@ -210,7 +216,7 @@ class InferShapeContext { ...@@ -210,7 +216,7 @@ class InferShapeContext {
[&](const std::string& sub_name) { [&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var, "MultiOutput(%s:%s) should not be nullptr", name, var, "MultiOutput(%s:%s) should not be nullptr.", name,
sub_name); sub_name);
return var->GetMutable<T>(); return var->GetMutable<T>();
}); });
......
...@@ -65,8 +65,9 @@ static void BuildVar(const std::string& param_name, ...@@ -65,8 +65,9 @@ static void BuildVar(const std::string& param_name,
} }
} }
REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest, REGISTER_OP_WITHOUT_GRADIENT(
paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker); test_operator, paddle::framework::OpWithoutKernelTest,
paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
TEST(OperatorBase, all) { TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
...@@ -184,8 +185,9 @@ class CPUKernalMultiInputsTest : public OpKernel { ...@@ -184,8 +185,9 @@ class CPUKernalMultiInputsTest : public OpKernel {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, REGISTER_OP_WITHOUT_GRADIENT(
paddle::framework::OpKernelTestProtoAndCheckerMaker); op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, REGISTER_OP_CPU_KERNEL(op_with_kernel,
paddle::framework::CPUKernelTest<float, float>); paddle::framework::CPUKernelTest<float, float>);
...@@ -210,8 +212,9 @@ TEST(OpKernel, all) { ...@@ -210,8 +212,9 @@ TEST(OpKernel, all) {
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
} }
REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest, REGISTER_OP_WITHOUT_GRADIENT(
paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker); op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel, REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
paddle::framework::CPUKernalMultiInputsTest); paddle::framework::CPUKernalMultiInputsTest);
......
...@@ -30,8 +30,8 @@ limitations under the License. */ ...@@ -30,8 +30,8 @@ limitations under the License. */
namespace py = pybind11; namespace py = pybind11;
USE_OP(add_two); USE_OP(add_two);
USE_CPU_OP(onehot_cross_entropy); USE_CPU_ONLY_OP(onehot_cross_entropy);
USE_NO_GRAD_OP(sgd); USE_OP(sgd);
USE_OP(mul); USE_OP(mul);
USE_OP(mean); USE_OP(mean);
USE_OP(sigmoid); USE_OP(sigmoid);
...@@ -160,13 +160,16 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -160,13 +160,16 @@ All parameter, weight, gradient are variables in Paddle.
//! @note: Be careful! PyBind will return std::string as an unicode, not //! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python. //! Python str. If you want a str object, you should cast them in Python.
m.def("get_all_op_protos", []() -> std::vector<py::bytes> { m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
auto &protos = OpProtos(); auto &op_info_map = OpRegistry::op_info_map();
std::vector<py::bytes> ret_values; std::vector<py::bytes> ret_values;
for (auto it = protos.begin(); it != protos.end(); ++it) { for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) {
PADDLE_ENFORCE(it->second.IsInitialized(), const OpProto *proto = it->second.proto_;
"OpProto must all be initialized"); if (proto == nullptr) {
continue;
}
PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized");
std::string str; std::string str;
PADDLE_ENFORCE(it->second.SerializeToString(&str), PADDLE_ENFORCE(proto->SerializeToString(&str),
"Serialize OpProto Error. This could be a bug of Paddle."); "Serialize OpProto Error. This could be a bug of Paddle.");
ret_values.push_back(py::bytes(str)); ret_values.push_back(py::bytes(str));
} }
......
...@@ -57,8 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel { ...@@ -57,8 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker); REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad);
REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add_two, REGISTER_OP_CPU_KERNEL(add_two,
ops::AddKernel<paddle::platform::CPUPlace, float>); ops::AddKernel<paddle::platform::CPUPlace, float>);
...@@ -68,12 +68,11 @@ OnehotCrossEntropy Operator. ...@@ -68,12 +68,11 @@ OnehotCrossEntropy Operator.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker); ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy, onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>); ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(onehot_cross_entropy, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy_grad, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>); ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
...@@ -46,7 +46,8 @@ The output will have the same size with input. ...@@ -46,7 +46,8 @@ The output will have the same size with input.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(fill_zeros_like, ops::FillZerosLikeOp, ops::FillZerosLikeOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp,
ops::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_zeros_like, fill_zeros_like,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>); ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
...@@ -81,5 +81,6 @@ Use to initialize tensor with gaussian random generator. ...@@ -81,5 +81,6 @@ Use to initialize tensor with gaussian random generator.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>); REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
...@@ -54,9 +54,8 @@ class MeanGradOp : public framework::OperatorWithKernel { ...@@ -54,9 +54,8 @@ class MeanGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker); REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean, REGISTER_OP_CPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::CPUPlace, float>); ops::MeanKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean_grad, REGISTER_OP_CPU_KERNEL(mean_grad,
ops::MeanGradKernel<paddle::platform::CPUPlace, float>); ops::MeanGradKernel<paddle::platform::CPUPlace, float>);
...@@ -70,7 +70,5 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -70,7 +70,5 @@ class MulOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker); REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
...@@ -246,5 +246,6 @@ RecurrentGradientOp::RecurrentGradientOp( ...@@ -246,5 +246,6 @@ RecurrentGradientOp::RecurrentGradientOp(
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(recurrent_op, paddle::operators::RecurrentOp, REGISTER_OP_WITHOUT_GRADIENT(
paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker); recurrent_op, paddle::operators::RecurrentOp,
paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker);
...@@ -54,6 +54,7 @@ for i in xrange(X.shape[0]): ...@@ -54,6 +54,7 @@ for i in xrange(X.shape[0]):
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker); REGISTER_OP_WITHOUT_GRADIENT(rowwise_add, ops::RowWiseAddOp,
ops::RowWiseAddOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
rowwise_add, ops::RowWiseAddKernel<paddle::platform::CPUPlace, float>); rowwise_add, ops::RowWiseAddKernel<paddle::platform::CPUPlace, float>);
...@@ -51,6 +51,6 @@ param_out = param - learning_rate * grad; ...@@ -51,6 +51,6 @@ param_out = param - learning_rate * grad;
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sgd, ops::SGDOp, ops::SGDOpMaker); REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
REGISTER_OP_CPU_KERNEL(sgd, REGISTER_OP_CPU_KERNEL(sgd,
ops::SGDOpKernel<paddle::platform::CPUPlace, float>); ops::SGDOpKernel<paddle::platform::CPUPlace, float>);
...@@ -52,9 +52,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel { ...@@ -52,9 +52,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker); REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad); ops::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid, REGISTER_OP_CPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::CPUPlace, float>); ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -62,9 +62,9 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ...@@ -62,9 +62,9 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker); REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad,
ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax, REGISTER_OP_CPU_KERNEL(softmax,
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>); ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
softmax_grad, ops::SoftmaxGradKernel<paddle::platform::CPUPlace, float>); softmax_grad, ops::SoftmaxGradKernel<paddle::platform::CPUPlace, float>);
...@@ -81,7 +81,7 @@ Used to initialize tensor with uniform random generator. ...@@ -81,7 +81,7 @@ Used to initialize tensor with uniform random generator.
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(uniform_random, paddle::operators::UniformRandomOp, REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
paddle::operators::UniformRandomOpMaker); paddle::operators::UniformRandomOpMaker);
REGISTER_OP_CPU_KERNEL(uniform_random, REGISTER_OP_CPU_KERNEL(uniform_random,
paddle::operators::CPUUniformRandomKernel<float>); paddle::operators::CPUUniformRandomKernel<float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册