提交 11c35605 编写于 作者: Y Yu Yang

Remove empty constructor for operator

上级 0b1052fc
......@@ -30,7 +30,7 @@ using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
};
......@@ -79,8 +79,9 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
class FcOp : public operators::NetOp {
public:
DEFINE_OPERATOR_CTOR(FcOp, operators::NetOp)
void Init() override {
FcOp(const std::string &type, const VarNameMap &inputs,
const VarNameMap &outputs, const AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
AddOp(OpRegistry::CreateOp("mul",
{{"X", {Input("X")}}, {"Y", {Input("W")}}},
{{"Out", {Output("mul_result")}}}, {}));
......
......@@ -23,13 +23,12 @@ class OpRegistry;
enum class OpArgType { IN, OUT };
static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
const OpArgType& src_type, const OpArgType& dst_type,
bool is_grad) {
static void TransOpArg(const OperatorBase* src_op,
OperatorBase::VarNameMap* vars,
const OpArgType& src_type, bool is_grad) {
const auto& src_inout =
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
auto& dst_inout =
dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
auto& dst_inout = *vars;
const OpProto& proto = OpProtos().at(src_op->type_);
const auto& src_arg_list =
......@@ -47,15 +46,22 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
}
OperatorBase* BuildGradOp(const OperatorBase* op) {
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
grad_op->attrs_ = op->attrs_;
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, false); // I
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, false); // O
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, true); // OG
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, true); // IG
return grad_op;
auto gop_type_it = OpRegistry::grad_ops().find(op->type_);
PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(),
"Operator %s do not register gradient type", op->type_);
auto& grad_op_type = gop_type_it->second;
OperatorBase::VarNameMap inputs;
OperatorBase::VarNameMap outputs;
TransOpArg(op, &inputs, OpArgType::IN, false); // I
TransOpArg(op, &inputs, OpArgType::OUT, false); // O
TransOpArg(op, &inputs, OpArgType::OUT, true); // OG
TransOpArg(op, &outputs, OpArgType::IN, true); // IG
auto gop_it = OpRegistry::op_creators().find(grad_op_type);
PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(),
"Operator %s 's Gradient %s's creator cannot be found",
op->type_, grad_op_type);
return gop_it->second(grad_op_type, inputs, outputs, op->attrs_);
}
} // namespace framework
......
......@@ -10,7 +10,7 @@ namespace framework {
class NOP : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(NOP, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
......
......@@ -117,13 +117,19 @@ class OpProtoAndCheckerMaker {
};
class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarNameMap = std::map<std::string, std::vector<std::string>>;
using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VarNameMap& /*inputs*/,
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
public:
template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) {
op_creators()[op_type] = [] { return new OpType; };
op_creators()[op_type] = [](
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs) {
return new OpType(type, inputs, outputs, attrs);
};
OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = OpProtos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker);
......@@ -138,29 +144,25 @@ class OpRegistry {
template <typename GradOpType>
static void RegisterGradOp(const std::string& op_type,
const std::string& grad_op_type) {
op_creators()[grad_op_type] = [] { return new GradOpType; };
op_creators()[grad_op_type] = [](
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs) {
return new GradOpType(type, inputs, outputs, attrs);
};
grad_ops()[op_type] = grad_op_type;
}
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
const AttributeMap& attrs) {
AttributeMap attrs) {
auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type);
op_checkers().at(type).Check(attrs);
auto op = op_create_it->second();
op->type_ = type;
op->inputs_ = inputs;
op->outputs_ = outputs;
op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);
GenerateTempVariableName(op);
auto op = op_create_it->second(type, inputs, outputs, attrs);
op->Init();
return std::shared_ptr<OperatorBase>(op);
}
......@@ -195,7 +197,6 @@ class OpRegistry {
PADDLE_ENFORCE(!op.IsNetOp(),
"Use framework::Backward to get backward ops");
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
grad_op->Init();
return grad_op;
}
......@@ -214,19 +215,6 @@ class OpRegistry {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
}
static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : op->outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += op->type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
};
class Registrar {
......
......@@ -7,7 +7,7 @@ namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase);
using OperatorBase::OperatorBase;
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {}
......@@ -28,7 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
......
......@@ -120,5 +120,21 @@ void OperatorBase::Rename(const std::string& old_name,
}
}
OperatorBase::OperatorBase(const std::string& type,
const OperatorBase::VarNameMap& inputs,
const OperatorBase::VarNameMap& outputs,
const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
} // namespace framework
} // namespace paddle
......@@ -66,10 +66,8 @@ class OperatorBase {
public:
using VarNameMap = std::map<std::string, std::vector<std::string>>;
OperatorBase() = default;
OperatorBase(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
const VarNameMap& outputs, const AttributeMap& attrs);
OperatorBase(const OperatorBase& o) = delete;
OperatorBase& operator=(const OperatorBase& o) = delete;
......@@ -86,10 +84,6 @@ class OperatorBase {
virtual std::string DebugString() const;
/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual void Init() {}
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const Scope& scope) const = 0;
......@@ -154,23 +148,14 @@ class OperatorBase {
// I (Inputs)
// O (Outputs)
// OG (Output Gradients)
std::map<std::string, std::vector<std::string>> inputs_;
VarNameMap inputs_;
// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
std::map<std::string, std::vector<std::string>> outputs_;
VarNameMap outputs_;
AttributeMap attrs_;
};
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() : ParentClass() { /* TODO(yi): This constructor is to be removed. */ \
} \
Class(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: ParentClass(type, inputs, outputs, attrs) {}
class InferShapeContext {
public:
InferShapeContext(const OperatorBase& op, const Scope& scope)
......@@ -310,8 +295,6 @@ class OpKernel {
class OperatorWithKernel : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(OperatorWithKernel, OperatorBase)
struct OpKernelKey {
platform::Place place_;
......@@ -335,6 +318,10 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
OperatorWithKernel(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(*this, scope));
}
......
......@@ -22,10 +22,10 @@ namespace framework {
static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase {
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, framework::OperatorBase)
public:
void Init() override { x = 1; }
OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs), x(1) {}
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
......@@ -38,7 +38,7 @@ class OpWithoutKernelTest : public OperatorBase {
}
public:
float x = 0;
int x{0};
};
class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......@@ -104,7 +104,9 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OpWithKernelTest, framework::OperatorWithKernel)
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {}
};
......
......@@ -18,7 +18,8 @@ namespace paddle {
namespace operators {
class AddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
......@@ -45,7 +46,9 @@ The equation is: Out = X + Y
};
class AddOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
......@@ -32,8 +34,9 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
};
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
......
......@@ -18,7 +18,8 @@ namespace paddle {
namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel);
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
......
......@@ -43,7 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
};
class GaussianRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel);
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& context) const override {
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class MeanOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
......@@ -38,7 +40,9 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
};
class MeanGradOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X"))
......
......@@ -18,7 +18,8 @@ namespace paddle {
namespace operators {
class MulOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOp, framework::OperatorWithKernel);
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
......@@ -53,7 +54,9 @@ The equation is: Out = X * Y
};
class MulOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOpGrad, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override {
......
......@@ -81,5 +81,11 @@ std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
return ret_val;
}
NetOp::NetOp(const std::string& type,
const framework::OperatorBase::VarNameMap& inputs,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
} // namespace operators
} // namespace paddle
......@@ -37,7 +37,9 @@ namespace operators {
class NetOp : public framework::OperatorBase {
public:
static const char kAll[];
DEFINE_OPERATOR_CTOR(NetOp, framework::OperatorBase);
NetOp() : framework::OperatorBase("plain_net", {}, {}, {}) {}
NetOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs);
/**
* Infer all the operators' input and output variables' shapes, will be called
......
......@@ -12,7 +12,7 @@ static int run_cnt = 0;
class TestOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(TestOp, framework::OperatorBase);
using framework::OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
......@@ -22,7 +22,7 @@ class TestOp : public framework::OperatorBase {
class EmptyOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, framework::OperatorBase);
using framework::OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
};
......@@ -44,14 +44,14 @@ TEST(OpKernel, all) {
auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr);
auto op1 = std::make_shared<TestOp>();
op1->inputs_ = {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}};
op1->outputs_ = {{"Out", {"y"}}};
auto op1 = std::shared_ptr<TestOp>(
new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"Out", {"y"}}}, {}));
net->AddOp(op1);
auto op2 = std::make_shared<TestOp>();
op2->inputs_ = {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}};
op2->outputs_ = {{"Out", {"z"}}};
auto op2 = std::shared_ptr<TestOp>(
new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}},
{{"Out", {"z"}}}, {}));
net->AddOp(op2);
net->CompleteAddOp();
......@@ -67,9 +67,9 @@ TEST(OpKernel, all) {
TEST(NetOp, insert_op) {
NetOp net;
auto op1 = std::make_shared<EmptyOp>();
op1->inputs_ = {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}};
op1->outputs_ = {{"Out", {"y"}}};
auto op1 = std::shared_ptr<EmptyOp>(
new EmptyOp("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"Out", {"y"}}}, {}));
net.AddOp(op1);
net.InsertOp(0, op1);
ASSERT_EQ(2UL, net.ops_.size());
......
......@@ -135,8 +135,11 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{
"inlink@grad", "inlink_alias", "outlink_alias",
"memories", "pre_memories", "boot_memories@grad"};
void RecurrentOp::Init() {
OperatorBase::Init();
RecurrentOp::RecurrentOp(const std::string& type,
const framework::OperatorBase::VarNameMap& inputs,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
......@@ -230,8 +233,11 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
}
void RecurrentGradientOp::Init() {
OperatorBase::Init();
RecurrentGradientOp::RecurrentGradientOp(
const std::string& type, const framework::OperatorBase::VarNameMap& inputs,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
......
......@@ -101,13 +101,11 @@ class RecurrentGradientAlgorithm {
class RecurrentOp final : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(RecurrentOp, framework::OperatorBase);
void Init() override;
RecurrentOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs);
/**
* InferShape must be called before Run.
*/
* InferShape must be called before Run.
*/
void InferShape(const framework::Scope& scope) const override {
alg_.InferShape(scope);
}
......@@ -125,8 +123,9 @@ class RecurrentOp final : public framework::OperatorBase {
class RecurrentGradientOp final : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(RecurrentGradientOp, framework::OperatorBase)
void Init() override;
RecurrentGradientOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs,
const framework::AttributeMap& attrs);
/**
* InferShape must be called before Run.
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(RowWiseAddOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class SGDOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SGDOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class SigmoidOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
......@@ -37,7 +39,9 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
};
class SigmoidOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOpGrad, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
class SoftmaxOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
......@@ -39,7 +41,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
};
class SoftmaxOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOpGrad, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
......
......@@ -46,7 +46,9 @@ class CPUUniformRandomKernel : public framework::OpKernel {
};
class UniformRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(UniformRandomOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册