提交 c4effc7d 编写于 作者: Y Yu Yang

Fix CI Test

上级 e119177a
......@@ -33,7 +33,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
op_desc.SetType(op.Type());
op_desc.SetAttrMap(op.Attrs());
auto& info = OpInfoMap::Instance().Get(op.Type());
auto grad_descs = info.grad_op_maker_(op_desc);
auto grad_descs = info.GradOpMaker()(op_desc);
std::vector<std::unique_ptr<OperatorBase>> grad_ops;
grad_ops.reserve(grad_descs.size());
std::transform(grad_descs.begin(), grad_descs.end(),
......@@ -49,6 +49,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
for (auto& grad_op : grad_ops) {
net_op->AppendOp(std::move(grad_op));
}
net_op->CompleteAddOp();
return std::unique_ptr<OperatorBase>(net_op);
}
}
......
......@@ -30,7 +30,6 @@ namespace framework {
struct OpInfo {
OpCreator creator_;
std::string grad_op_type_;
GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
......@@ -51,6 +50,12 @@ struct OpInfo {
"Operator Creator has not been registered");
return creator_;
}
const GradOpMakerFN& GradOpMaker() const {
PADDLE_ENFORCE_NOT_NULL(grad_op_maker_,
"Operator GradOpMaker has not been registered.");
return grad_op_maker_;
}
};
class OpInfoMap {
......
......@@ -137,23 +137,21 @@ class OpKernelRegistrar : public Registrar {
__test_global_namespace_##uniq_name##__>::value, \
msg)
#define VA_ARGS(...) , ##__VA_ARGS__
#define REGISTER_OPERATOR(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, \
"REGISTER_OPERATOR must be called in global namespace"); \
class _OpClass_##op_type##_ : public op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
}; \
static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_ VA_ARGS( \
__VA_ARGS__)> \
__op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
#define REGISTER_OPERATOR(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, \
"REGISTER_OPERATOR must be called in global namespace"); \
class _OpClass_##op_type##_ : public op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
}; \
static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_, \
##__VA_ARGS__> \
__op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
}
/**
......@@ -170,7 +168,7 @@ class OpKernelRegistrar : public Registrar {
virtual std::string GradOpType() const { return #grad_op_type; } \
}; \
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
op_maker_class)
op_maker_class);
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
REGISTER_OPERATOR(op_type, op_class, op_maker_class)
......
......@@ -72,19 +72,26 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
const override {
std::vector<std::unique_ptr<framework::OpDescBind>> ops;
ops.resize(2);
ops[0].reset(new framework::OpDescBind());
ops[0]->SetType("scale");
ops[0]->SetInput("X", OutputGrad("Out"));
ops[0]->SetOutput("Out", InputGrad("X"));
ops[0]->SetAttr("scale", 1.0f);
ops[1].reset(new framework::OpDescBind());
ops[1]->SetType("scale");
ops[1]->SetInput("X", OutputGrad("Out"));
ops[1]->SetOutput("Out", InputGrad("Y"));
ops[1]->SetAttr("scale", -1.0f);
auto x_g = InputGrad("X");
if (!x_g.empty()) {
auto *x_g_op = new framework::OpDescBind();
x_g_op->SetType("scale");
x_g_op->SetInput("X", OutputGrad("Out"));
x_g_op->SetOutput("Out", x_g);
x_g_op->SetAttr("scale", 1.0f);
ops.emplace_back(x_g_op);
}
auto y_g = InputGrad("Y");
if (!y_g.empty()) {
auto *y_g_op = new framework::OpDescBind();
y_g_op->SetType("scale");
y_g_op->SetInput("X", OutputGrad("Out"));
y_g_op->SetOutput("Out", y_g);
y_g_op->SetAttr("scale", -1.0f);
ops.emplace_back(y_g_op);
}
return ops;
}
};
......
......@@ -121,6 +121,7 @@ class PadOpGradMaker : public framework::SingleGradOpDescMaker {
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
bind->SetAttrMap(Attrs());
bind->SetType("pad_grad");
return std::unique_ptr<framework::OpDescBind>(bind);
}
};
......
......@@ -14,6 +14,12 @@
#include "paddle/operators/softmax_with_cross_entropy_op.h"
#include <paddle/function/TensorType.h>
#include <iostream>
#define DBG_LINE() \
do { \
std::cerr << "Run at " << __LINE__ << std::endl; \
} while (false)
namespace paddle {
namespace operators {
......@@ -187,8 +193,7 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
ops::SoftmaxWithCrossEntropyOpMaker,
ops::SoftmaxWithCrossEntropyOpMaker);
ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker);
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册