提交 72b5bd93 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #3036 from Canpio/dev_update_backward

update gradient operator registry mechanism
......@@ -19,10 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator)
cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
......
......@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
OperatorBase* GradOpCreator::Create() {
OperatorBase* GradOpBuilder::Build() {
BuildOpInOutArgList();
OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
CompleteGradOp(grad_op);
return grad_op;
}
OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
const VarIndexMap& var_map,
const std::vector<int>& format,
InOutType type) {
......@@ -36,7 +38,7 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
end_idx);
}
void GradOpCreator::BuildOpInOutArgList() {
void GradOpBuilder::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
const std::vector<int>& in_format =
......@@ -57,7 +59,7 @@ void GradOpCreator::BuildOpInOutArgList() {
}
}
void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
std::vector<std::string>& in_out,
std::vector<int>& format,
VarIndexMap* varmap, int& idx,
......@@ -80,8 +82,7 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
format.push_back(in_out.size());
}
void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->type_ = op_->type_ + "@GRAD"; // not necessary
void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->attrs_ = op_->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
......
......@@ -25,12 +25,12 @@ struct OpInOutArg {
size_t end_idx_;
};
class GradOpCreator {
class GradOpBuilder {
using VarIndexMap = std::unordered_map<std::string, int>;
public:
GradOpCreator(const OperatorBase* op) : op_(op) {}
OperatorBase* Create();
GradOpBuilder(const OperatorBase* op) : op_(op) {}
OperatorBase* Build();
private:
OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
......
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/grad_op_builder.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
......@@ -8,7 +8,7 @@ USE_OP(add_two);
namespace paddle {
namespace framework {
TEST(GradOpCreator, AddTwo) {
TEST(GradOpBuilder, AddTwo) {
std::shared_ptr<OperatorBase> add_op(
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
......
......@@ -20,7 +20,7 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
......@@ -222,7 +222,7 @@ class OpRegistry {
public:
template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) {
creators()[op_type] = [] { return new OpType; };
op_creators()[op_type] = [] { return new OpType; };
OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = protos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker);
......@@ -245,17 +245,19 @@ class OpRegistry {
}
}
template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
grad_creators()[op_type] = [] { return new OpType; };
template <typename GradOpType>
static void RegisterGradOp(const std::string& op_type,
const std::string& grad_op_type) {
op_creators()[grad_op_type] = [] { return new GradOpType; };
grad_ops()[op_type] = grad_op_type;
}
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
const AttributeMap& attrs) {
auto op_create_it = creators().find(type);
PADDLE_ENFORCE(op_create_it != creators().end(),
auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type);
auto op = op_create_it->second();
......@@ -300,8 +302,8 @@ class OpRegistry {
static std::shared_ptr<OperatorBase> CreateGradOp(
std::shared_ptr<OperatorBase> op) {
GradOpCreator creator(op.get());
std::shared_ptr<OperatorBase> grad_op(creator.Create());
GradOpBuilder builder(op.get());
std::shared_ptr<OperatorBase> grad_op(builder.Build());
grad_op->Init();
return grad_op;
}
......@@ -311,9 +313,9 @@ class OpRegistry {
return protos_;
};
static std::unordered_map<std::string, OpCreator>& grad_creators() {
static std::unordered_map<std::string, OpCreator> grad_creators_;
return grad_creators_;
static std::unordered_map<std::string, std::string>& grad_ops() {
static std::unordered_map<std::string, std::string> grad_ops_;
return grad_ops_;
}
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
......@@ -322,12 +324,12 @@ class OpRegistry {
return maps_;
}
private:
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
static std::unordered_map<std::string, OpCreator>& op_creators() {
static std::unordered_map<std::string, OpCreator> op_creators_;
return op_creators_;
}
private:
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
......@@ -353,11 +355,11 @@ class OpRegisterHelper {
}
};
template <typename OpType>
template <typename GradOpType>
class GradOpRegisterHelper {
public:
GradOpRegisterHelper(const char* op_type) {
OpRegistry::RegisterGradOp<OpType>(op_type);
GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
}
};
......@@ -383,13 +385,16 @@ class GradOpRegisterHelper {
/**
* Macro to Register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__op_class> \
__op_gradient_register_##__op_type##__(#__op_type); \
int __op_gradient_register_##__op_type##_handle__() { return 0; }
#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type##__grad_op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
__op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
#__grad_op_type); \
int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
return 0; \
}
/**
* Macro to Register OperatorKernel.
......
......@@ -65,6 +65,6 @@ protected:
} // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
REGISTER_OP_CPU_KERNEL(
add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
......@@ -22,7 +22,7 @@ TEST(AddOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
auto it = protos.find("add_two");
ASSERT_NE(it, protos.end());
auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
auto it1 = grad_creators.find("add_two");
ASSERT_NE(it1, grad_creators.end());
auto& op_creators = paddle::framework::OpRegistry::op_creators();
auto it1 = op_creators.find("add_two_grad");
ASSERT_NE(it1, op_creators.end());
}
......@@ -67,7 +67,7 @@ protected:
} // namespace paddle
REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
REGISTER_GRADIENT_OP(mul, mul_grad, paddle::operators::MulOpGrad);
REGISTER_OP_CPU_KERNEL(
mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
......@@ -56,7 +56,7 @@ protected:
REGISTER_OP(sigmoid,
paddle::operators::SigmoidOp,
paddle::operators::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, paddle::operators::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(
sigmoid,
......
......@@ -59,6 +59,6 @@ protected:
namespace ops = paddle::operators;
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
REGISTER_GRADIENT_OP(softmax, softmax_grad, paddle::operators::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax,
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册