Commit 72b5bd93 authored by fengjiayi, committed by GitHub

Merge pull request #3036 from Canpio/dev_update_backward

update gradient operator registry mechanism
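In short: a gradient operator is no longer registered as an anonymous creator keyed only by the forward op's type. `RegisterGradOp` now takes an explicit gradient op type name, records a forward-type to gradient-type mapping in `grad_ops()`, and puts the gradient op's creator into the common `op_creators()` map, so every gradient op carries its own type such as `add_two_grad`. A minimal usage sketch assembled from the changes below (the registration lines mirror the updated `add_two` operator, the lookup mirrors the updated `GradOpBuilder` test; the snippet is illustrative, not new code in this commit):

// Register the forward op and its gradient op; the gradient now has its
// own type name ("add_two_grad") instead of the old implicit "<type>@GRAD".
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);

// Build the gradient op from a forward op instance:
std::shared_ptr<paddle::framework::OperatorBase> add_op(
    paddle::framework::OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
auto grad_add_op = paddle::framework::OpRegistry::CreateGradOp(add_op);
// grad_add_op->type_ == "add_two_grad"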
@@ -19,10 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
 cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator)
-cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator)
+cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
+cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
-cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op)
+cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
 py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
......
@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/framework/grad_op_creator.h"
+#include "paddle/framework/grad_op_builder.h"
 #include "paddle/framework/op_registry.h"
 namespace paddle {
 namespace framework {
-OperatorBase* GradOpCreator::Create() {
+OperatorBase* GradOpBuilder::Build() {
   BuildOpInOutArgList();
-  OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
+  std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_);
+  OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
+  grad_op->type_ = grad_op_type;
   CompleteGradOp(grad_op);
   return grad_op;
 }
-OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
+OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
                                     const VarIndexMap& var_map,
                                     const std::vector<int>& format,
                                     InOutType type) {
@@ -36,7 +38,7 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
                     end_idx);
 }
-void GradOpCreator::BuildOpInOutArgList() {
+void GradOpBuilder::BuildOpInOutArgList() {
   const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
   const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
   const std::vector<int>& in_format =
@@ -57,7 +59,7 @@ void GradOpCreator::BuildOpInOutArgList() {
   }
 }
-void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
+void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
                                      std::vector<std::string>& in_out,
                                      std::vector<int>& format,
                                      VarIndexMap* varmap, int& idx,
@@ -80,8 +82,7 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
   format.push_back(in_out.size());
 }
-void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
-  grad_op->type_ = op_->type_ + "@GRAD";  // not necessary
+void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
   grad_op->attrs_ = op_->attrs_;
   grad_op->attrs_.erase("input_format");
   grad_op->attrs_.erase("output_format");
......
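The core of `GradOpBuilder::Build()` is now a two-step lookup: the forward op's type resolves to a gradient op type via `OpRegistry::grad_ops()`, and that type resolves to a creator via the shared `OpRegistry::op_creators()` map, after which the gradient op is stamped with its type name. A self-contained sketch of that dispatch pattern follows; `Op`, `Creator`, and `MakeGradOp` are illustrative stand-ins, not the real framework types:

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct Op { std::string type_; };              // stand-in for OperatorBase
using Creator = std::function<Op*()>;

std::unordered_map<std::string, Creator> op_creators;   // op type  -> creator
std::unordered_map<std::string, std::string> grad_ops;  // fwd type -> grad type

// Mirrors Build(): fwd type -> grad type -> creator, then set the type name.
std::unique_ptr<Op> MakeGradOp(const Op& fwd) {
  const std::string& grad_type = grad_ops.at(fwd.type_);
  std::unique_ptr<Op> grad(op_creators.at(grad_type)());
  grad->type_ = grad_type;
  return grad;
}

int main() {
  op_creators["add_two_grad"] = [] { return new Op{}; };
  grad_ops["add_two"] = "add_two_grad";
  auto g = MakeGradOp(Op{"add_two"});  // g->type_ == "add_two_grad"
}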
@@ -25,12 +25,12 @@ struct OpInOutArg {
   size_t end_idx_;
 };
-class GradOpCreator {
+class GradOpBuilder {
   using VarIndexMap = std::unordered_map<std::string, int>;
  public:
-  GradOpCreator(const OperatorBase* op) : op_(op) {}
-  OperatorBase* Create();
+  GradOpBuilder(const OperatorBase* op) : op_(op) {}
+  OperatorBase* Build();
  private:
   OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
......
#include "paddle/framework/grad_op_creator.h" #include "paddle/framework/grad_op_builder.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
...@@ -8,7 +8,7 @@ USE_OP(add_two); ...@@ -8,7 +8,7 @@ USE_OP(add_two);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
TEST(GradOpCreator, AddTwo) { TEST(GradOpBuilder, AddTwo) {
std::shared_ptr<OperatorBase> add_op( std::shared_ptr<OperatorBase> add_op(
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op); std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
......
@@ -20,7 +20,7 @@ limitations under the License. */
 #include <unordered_map>
 #include <unordered_set>
 #include "paddle/framework/attr_checker.h"
-#include "paddle/framework/grad_op_creator.h"
+#include "paddle/framework/grad_op_builder.h"
 #include "paddle/framework/op_desc.pb.h"
 #include "paddle/framework/scope.h"
@@ -222,7 +222,7 @@ class OpRegistry {
  public:
   template <typename OpType, typename ProtoMakerType>
   static void RegisterOp(const std::string& op_type) {
-    creators()[op_type] = [] { return new OpType; };
+    op_creators()[op_type] = [] { return new OpType; };
     OpAttrChecker& op_checker = op_checkers()[op_type];
     OpProto& op_proto = protos()[op_type];
     auto maker = ProtoMakerType(&op_proto, &op_checker);
@@ -245,17 +245,19 @@ class OpRegistry {
     }
   }
-  template <typename OpType>
-  static void RegisterGradOp(const std::string& op_type) {
-    grad_creators()[op_type] = [] { return new OpType; };
+  template <typename GradOpType>
+  static void RegisterGradOp(const std::string& op_type,
+                             const std::string& grad_op_type) {
+    op_creators()[grad_op_type] = [] { return new GradOpType; };
+    grad_ops()[op_type] = grad_op_type;
   }
   static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
                                                 const VarNameList& inputs,
                                                 const VarNameList& outputs,
                                                 const AttributeMap& attrs) {
-    auto op_create_it = creators().find(type);
-    PADDLE_ENFORCE(op_create_it != creators().end(),
+    auto op_create_it = op_creators().find(type);
+    PADDLE_ENFORCE(op_create_it != op_creators().end(),
                    "Operator %s cannot be found.", type);
     auto op = op_create_it->second();
@@ -300,8 +302,8 @@ class OpRegistry {
   static std::shared_ptr<OperatorBase> CreateGradOp(
       std::shared_ptr<OperatorBase> op) {
-    GradOpCreator creator(op.get());
-    std::shared_ptr<OperatorBase> grad_op(creator.Create());
+    GradOpBuilder builder(op.get());
+    std::shared_ptr<OperatorBase> grad_op(builder.Build());
     grad_op->Init();
     return grad_op;
   }
@@ -311,9 +313,9 @@ class OpRegistry {
     return protos_;
   };
-  static std::unordered_map<std::string, OpCreator>& grad_creators() {
-    static std::unordered_map<std::string, OpCreator> grad_creators_;
-    return grad_creators_;
+  static std::unordered_map<std::string, std::string>& grad_ops() {
+    static std::unordered_map<std::string, std::string> grad_ops_;
+    return grad_ops_;
   }
   static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
@@ -322,12 +324,12 @@ class OpRegistry {
     return maps_;
   }
- private:
-  static std::unordered_map<std::string, OpCreator>& creators() {
-    static std::unordered_map<std::string, OpCreator> creators_;
-    return creators_;
+  static std::unordered_map<std::string, OpCreator>& op_creators() {
+    static std::unordered_map<std::string, OpCreator> op_creators_;
+    return op_creators_;
   }
+ private:
   static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
     static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
     return op_checkers_;
@@ -353,11 +355,11 @@ class OpRegisterHelper {
   }
 };
-template <typename OpType>
+template <typename GradOpType>
 class GradOpRegisterHelper {
  public:
-  GradOpRegisterHelper(const char* op_type) {
-    OpRegistry::RegisterGradOp<OpType>(op_type);
+  GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
+    OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
   }
 };
@@ -383,13 +385,16 @@ class GradOpRegisterHelper {
 /**
  * Macro to Register Gradient Operator.
  */
-#define REGISTER_GRADIENT_OP(__op_type, __op_class)                        \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                          \
-      __reg_gradient_op__##__op_type,                                      \
-      "REGISTER_GRADIENT_OP must be in global namespace");                 \
-  static ::paddle::framework::GradOpRegisterHelper<__op_class>             \
-      __op_gradient_register_##__op_type##__(#__op_type);                  \
-  int __op_gradient_register_##__op_type##_handle__() { return 0; }
+#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class)   \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                          \
+      __reg_gradient_op__##__op_type##__grad_op_type,                      \
+      "REGISTER_GRADIENT_OP must be in global namespace");                 \
+  static ::paddle::framework::GradOpRegisterHelper<__grad_op_class>        \
+      __op_gradient_register_##__op_type##__grad_op_type##__(#__op_type,   \
+                                                             #__grad_op_type); \
+  int __op_gradient_register_##__op_type##__grad_op_type##_handle__() {    \
+    return 0;                                                              \
+  }
 /**
  * Macro to Register OperatorKernel.
......
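For reference, with the extra `__grad_op_type` parameter a call such as `REGISTER_GRADIENT_OP(add_two, add_two_grad, AddOpGrad)` expands (ignoring the namespace-guard static assert, class name shortened, token pasting shown literally) roughly to a file-scope helper object whose constructor fills `grad_ops()` and `op_creators()` at static-initialization time:

// Approximate expansion -- illustrative only.
static ::paddle::framework::GradOpRegisterHelper<AddOpGrad>
    __op_gradient_register_add_twoadd_two_grad__("add_two", "add_two_grad");
int __op_gradient_register_add_twoadd_two_grad_handle__() { return 0; }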
@@ -65,6 +65,6 @@ protected:
 }  // namespace paddle
 REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
-REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
+REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
 REGISTER_OP_CPU_KERNEL(
     add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
@@ -22,7 +22,7 @@ TEST(AddOp, GetOpProto) {
   auto& protos = paddle::framework::OpRegistry::protos();
   auto it = protos.find("add_two");
   ASSERT_NE(it, protos.end());
-  auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
-  auto it1 = grad_creators.find("add_two");
-  ASSERT_NE(it1, grad_creators.end());
+  auto& op_creators = paddle::framework::OpRegistry::op_creators();
+  auto it1 = op_creators.find("add_two_grad");
+  ASSERT_NE(it1, op_creators.end());
 }
@@ -67,7 +67,7 @@ protected:
 }  // namespace paddle
 REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
-REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
+REGISTER_GRADIENT_OP(mul, mul_grad, paddle::operators::MulOpGrad);
 REGISTER_OP_CPU_KERNEL(
     mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
@@ -56,7 +56,7 @@ protected:
 REGISTER_OP(sigmoid,
             paddle::operators::SigmoidOp,
             paddle::operators::SigmoidOpMaker);
-REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
+REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, paddle::operators::SigmoidOpGrad);
 REGISTER_OP_CPU_KERNEL(
     sigmoid,
......
@@ -59,6 +59,6 @@ protected:
 namespace ops = paddle::operators;
 REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
-REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
+REGISTER_GRADIENT_OP(softmax, softmax_grad, paddle::operators::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax,
                        ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
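Taken together, these operator-side changes mean the registry now records the pairs add_two → add_two_grad, mul → mul_grad, sigmoid → sigmoid_grad, and softmax → softmax_grad in `grad_ops()`, while each `*_grad` type gets its own entry in `op_creators()`; the updated add_two test above checks exactly that by looking up `add_two_grad` in `op_creators()`.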