提交 ddb29b6c 编写于 作者: Y Yi Wang

Move constants from framework::OperatorBase to framework::

上级 6f12fd28
...@@ -59,7 +59,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -59,7 +59,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// If all input gradients of forwarding operator do not need to calculate, // If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take // just return an NOP. Not return null ptr because NOP does not take
// too much time for calculation, but it is useful for simplifying logic. // too much time for calculation, but it is useful for simplifying logic.
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(), if (AllInSet(forwardOp.inputs_, kGradVarSuffix,
no_grad_names)) { no_grad_names)) {
return NOP(); return NOP();
} }
...@@ -67,11 +67,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -67,11 +67,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// All output gradients of forwarding operator do not need to calculate. // All output gradients of forwarding operator do not need to calculate.
// Then all input gradients cannot be computed at all, and we put them into // Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP. // `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(), if (AllInSet(forwardOp.outputs_, kGradVarSuffix,
no_grad_names)) { no_grad_names)) {
for (auto& name : forwardOp.inputs_) { for (auto& name : forwardOp.inputs_) {
// Mark all input is not need // Mark all input is not need
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX()); no_grad_names.insert(name + kGradVarSuffix);
} }
return NOP(); return NOP();
} }
...@@ -135,8 +135,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -135,8 +135,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
for (std::string& grad_input : grad_op->inputs_) { for (std::string& grad_input : grad_op->inputs_) {
if (no_grad_names.count(grad_input)) { if (no_grad_names.count(grad_input)) {
std::string prefix = grad_input.substr( std::string prefix = grad_input.substr(
0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size()); 0, grad_input.size() - kGradVarSuffix.size());
grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX(); grad_input = prefix + kZeroVarSuffix;
// If part of input gradient of that operator is not calculated, fill // If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient. // zero variables to that input gradient.
...@@ -147,7 +147,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -147,7 +147,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
for (std::string& grad_output : grad_op->outputs_) { for (std::string& grad_output : grad_op->outputs_) {
if (no_grad_names.count(grad_output)) { if (no_grad_names.count(grad_output)) {
grad_output = OperatorBase::EMPTY_VAR_NAME(); grad_output = kEmptyVarName;
} }
} }
...@@ -168,11 +168,11 @@ std::shared_ptr<OperatorBase> Backward( ...@@ -168,11 +168,11 @@ std::shared_ptr<OperatorBase> Backward(
std::unordered_set<std::string> no_grad_names; std::unordered_set<std::string> no_grad_names;
no_grad_names.reserve(no_grad_vars.size()); no_grad_names.reserve(no_grad_vars.size());
no_grad_names.insert(OperatorBase::EMPTY_VAR_NAME() + no_grad_names.insert(kEmptyVarName +
OperatorBase::GRAD_VAR_SUFFIX()); kGradVarSuffix);
for (auto& name : no_grad_vars) { for (auto& name : no_grad_vars) {
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX()); no_grad_names.insert(name + kGradVarSuffix);
} }
size_t uid = 0; size_t uid = 0;
return BackwardRecursive(forwardOp, no_grad_names, uid); return BackwardRecursive(forwardOp, no_grad_names, uid);
......
...@@ -78,14 +78,14 @@ class FcOp : public ops::NetOp { ...@@ -78,14 +78,14 @@ class FcOp : public ops::NetOp {
{Output("mul_result")}, {})); {Output("mul_result")}, {}));
auto b_name = Input("b"); auto b_name = Input("b");
std::string before_act = "mul_result"; std::string before_act = "mul_result";
if (b_name != EMPTY_VAR_NAME()) { if (b_name != kEmptyVarName) {
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name}, AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name},
{Output("add_result")}, {})); {Output("add_result")}, {}));
before_act = "add_result"; before_act = "add_result";
} else { } else {
auto out_varname = Output("add_result"); auto out_varname = Output("add_result");
if (out_varname != EMPTY_VAR_NAME()) { if (out_varname != kEmptyVarName) {
this->Rename(out_varname, EMPTY_VAR_NAME()); this->Rename(out_varname, kEmptyVarName);
} }
} }
...@@ -163,13 +163,13 @@ TEST(Backward, simple_op_grad) { ...@@ -163,13 +163,13 @@ TEST(Backward, simple_op_grad) {
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd); auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ(4UL, gop->inputs_.size()); ASSERT_EQ(4UL, gop->inputs_.size());
ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), gop->inputs_[0]); ASSERT_EQ(f::kEmptyVarName, gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_); ASSERT_EQ("rowwise_add_grad", gop->type_);
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]); ASSERT_EQ("X" + f::kGradVarSuffix, gop->outputs_[0]);
ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]); ASSERT_EQ("b" + f::kGradVarSuffix, gop->outputs_[1]);
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), ASSERT_EQ("X" + f::kGradVarSuffix,
gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX())); gop->Output("X" + f::kGradVarSuffix));
} }
TEST(Backward, simple_op_not_need_grad) { TEST(Backward, simple_op_not_need_grad) {
...@@ -177,7 +177,7 @@ TEST(Backward, simple_op_not_need_grad) { ...@@ -177,7 +177,7 @@ TEST(Backward, simple_op_not_need_grad) {
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
auto gop = f::Backward(*fwd, {"X"}); auto gop = f::Backward(*fwd, {"X"});
ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(), ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
"X" + f::OperatorBase::GRAD_VAR_SUFFIX()), "X" + f::kGradVarSuffix),
gop->outputs_.end()); gop->outputs_.end());
auto no_input_gop = f::Backward(*fwd, {"X", "b"}); auto no_input_gop = f::Backward(*fwd, {"X", "b"});
...@@ -211,7 +211,7 @@ TEST(Backward, net_fc_backward_normal) { ...@@ -211,7 +211,7 @@ TEST(Backward, net_fc_backward_normal) {
TEST(Backward, net_fc_backward_not_have_b) { TEST(Backward, net_fc_backward_not_have_b) {
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp( std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
"fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, "fc", {"X", "w", f::kEmptyVarName},
{"mul_result", "add_result", "tmp"}, {}); {"mul_result", "add_result", "tmp"}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {}); std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
...@@ -242,15 +242,15 @@ TEST(Backward, net_input_of_network_not_need_grad) { ...@@ -242,15 +242,15 @@ TEST(Backward, net_input_of_network_not_need_grad) {
std::unordered_set<std::string> all_output = std::unordered_set<std::string>( std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
bwd_net->outputs_.begin(), bwd_net->outputs_.end()); bwd_net->outputs_.begin(), bwd_net->outputs_.end());
all_output.erase(f::OperatorBase::EMPTY_VAR_NAME()); all_output.erase(f::kEmptyVarName);
for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) { for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()), ASSERT_NE(all_output.find(out + f::kGradVarSuffix),
all_output.end()); all_output.end());
} }
// Not Generated X // Not Generated X
ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), ASSERT_EQ(all_output.find("X" + f::kGradVarSuffix),
all_output.end()); all_output.end());
ASSERT_EQ(2UL, bwd_net->ops_.size()); ASSERT_EQ(2UL, bwd_net->ops_.size());
...@@ -258,8 +258,8 @@ TEST(Backward, net_input_of_network_not_need_grad) { ...@@ -258,8 +258,8 @@ TEST(Backward, net_input_of_network_not_need_grad) {
auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get()); auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
ASSERT_EQ(3UL, first_fc_grad->ops_.size()); ASSERT_EQ(3UL, first_fc_grad->ops_.size());
ASSERT_EQ( ASSERT_EQ(
f::OperatorBase::EMPTY_VAR_NAME(), f::kEmptyVarName,
first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX())); first_fc_grad->ops_[2]->Output("A" + f::kGradVarSuffix));
} }
TEST(Backward, net_shared_weight) { TEST(Backward, net_shared_weight) {
...@@ -311,17 +311,17 @@ TEST(Backward, op_part_of_output_are_not_need) { ...@@ -311,17 +311,17 @@ TEST(Backward, op_part_of_output_are_not_need) {
ASSERT_EQ(1UL, fill_zero.inputs_.size()); ASSERT_EQ(1UL, fill_zero.inputs_.size());
ASSERT_EQ("Z", fill_zero.inputs_[0]); ASSERT_EQ("Z", fill_zero.inputs_[0]);
ASSERT_EQ(1UL, fill_zero.outputs_.size()); ASSERT_EQ(1UL, fill_zero.outputs_.size());
ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(), fill_zero.outputs_[0]); ASSERT_EQ("Z" + f::kZeroVarSuffix, fill_zero.outputs_[0]);
auto &d_many_out = *net->ops_[1]; auto &d_many_out = *net->ops_[1];
ASSERT_EQ("many_output_op_grad", d_many_out.type_); ASSERT_EQ("many_output_op_grad", d_many_out.type_);
ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG
ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(), ASSERT_EQ("Z" + f::kZeroVarSuffix,
d_many_out.Input("z" + f::OperatorBase::GRAD_VAR_SUFFIX())); d_many_out.Input("z" + f::kGradVarSuffix));
ASSERT_EQ("Y" + f::OperatorBase::GRAD_VAR_SUFFIX(), ASSERT_EQ("Y" + f::kGradVarSuffix,
d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX())); d_many_out.Input("y" + f::kGradVarSuffix));
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), ASSERT_EQ("X" + f::kGradVarSuffix,
d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX())); d_many_out.Output("x" + f::kGradVarSuffix));
} }
TEST(Backward, op_part_of_input_are_not_need) { TEST(Backward, op_part_of_input_are_not_need) {
...@@ -331,12 +331,12 @@ TEST(Backward, op_part_of_input_are_not_need) { ...@@ -331,12 +331,12 @@ TEST(Backward, op_part_of_input_are_not_need) {
ASSERT_EQ(grad_mul.type_, "mul_grad"); ASSERT_EQ(grad_mul.type_, "mul_grad");
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
ASSERT_EQ(grad_mul.outputs_.size(), 2UL); ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()), ASSERT_EQ(grad_mul.Output("A" + f::kGradVarSuffix),
f::OperatorBase::EMPTY_VAR_NAME()); f::kEmptyVarName);
ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()), ASSERT_EQ(grad_mul.Output("B" + f::kGradVarSuffix),
"b" + f::OperatorBase::GRAD_VAR_SUFFIX()); "b" + f::kGradVarSuffix);
ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), ASSERT_EQ(grad_mul.Input("Out" + f::kGradVarSuffix),
"out" + f::OperatorBase::GRAD_VAR_SUFFIX()); "out" + f::kGradVarSuffix);
ASSERT_EQ(grad_mul.Input("A"), "a"); ASSERT_EQ(grad_mul.Input("A"), "a");
ASSERT_EQ(grad_mul.Input("B"), "b"); ASSERT_EQ(grad_mul.Input("B"), "b");
ASSERT_EQ(grad_mul.Input("Out"), "out"); ASSERT_EQ(grad_mul.Input("Out"), "out");
...@@ -370,17 +370,17 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -370,17 +370,17 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL); EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
/* /*
EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), EXPECT_EQ(grad_fc.Output("X" + f::kGradVarSuffix),
f::OperatorBase::EMPTY_VAR_NAME()); f::kEmptyVarName);
EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()), EXPECT_EQ(grad_fc.Output("W" + f::kGradVarSuffix),
"w3" + f::OperatorBase::GRAD_VAR_SUFFIX()); "w3" + f::kGradVarSuffix);
EXPECT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()), EXPECT_EQ(grad_fc.Output("b" + f::kGradVarSuffix),
"b3" + f::OperatorBase::GRAD_VAR_SUFFIX()); "b3" + f::kGradVarSuffix);
EXPECT_EQ(grad_fc.Output("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()), EXPECT_EQ(grad_fc.Output("mul_result" + f::kGradVarSuffix),
"mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); "mul_out3" + f::kGradVarSuffix);
EXPECT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), EXPECT_EQ(grad_fc.Input("Out" + f::kGradVarSuffix),
"out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); "out3" + f::kGradVarSuffix);
EXPECT_EQ(grad_fc.Input("X"), "out2"); EXPECT_EQ(grad_fc.Input("X"), "out2");
EXPECT_EQ(grad_fc.Input("W"), "w3"); EXPECT_EQ(grad_fc.Input("W"), "w3");
EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3"); EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3");
......
...@@ -57,7 +57,7 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, ...@@ -57,7 +57,7 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
for (const auto& arg : src_arg_list) { for (const auto& arg : src_arg_list) {
std::string src_name = arg.name(); std::string src_name = arg.name();
std::string dst_name = std::string dst_name =
is_grad ? src_name + OperatorBase::GRAD_VAR_SUFFIX() : src_name; is_grad ? src_name + kGradVarSuffix : src_name;
(*dst_op->in_out_idxs_)[dst_name] = idx++; (*dst_op->in_out_idxs_)[dst_name] = idx++;
int src_arg_idx = src_op->in_out_idxs_->at(src_name); int src_arg_idx = src_op->in_out_idxs_->at(src_name);
int src_begin = int src_begin =
...@@ -65,9 +65,9 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, ...@@ -65,9 +65,9 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
int src_end = src_format == nullptr ? src_arg_idx + 1 int src_end = src_format == nullptr ? src_arg_idx + 1
: src_format->at(src_arg_idx + 1); : src_format->at(src_arg_idx + 1);
for (int i = src_begin; i < src_end; ++i) { for (int i = src_begin; i < src_end; ++i) {
std::string s = is_grad ? src_inout[i] + OperatorBase::GRAD_VAR_SUFFIX() std::string s = is_grad ? src_inout[i] + kGradVarSuffix
: arg.ignore_gradient() : arg.ignore_gradient()
? OperatorBase::EMPTY_VAR_NAME() ? kEmptyVarName
: src_inout[i]; : src_inout[i];
dst_inout.emplace_back(s); dst_inout.emplace_back(s);
} }
......
...@@ -341,7 +341,7 @@ class OpRegistry { ...@@ -341,7 +341,7 @@ class OpRegistry {
static void GenerateTempVariableName(OperatorBase* op) { static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL); static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) { for (auto& outname : op->outputs_) {
if (outname == OperatorBase::TMP_VAR_NAME()) { if (outname == kTempVarName) {
outname += op->type_; outname += op->type_;
outname += "@"; outname += "@";
outname += std::to_string(gUniqId.fetch_add(1)); outname += std::to_string(gUniqId.fetch_add(1));
......
...@@ -32,9 +32,30 @@ limitations under the License. */ ...@@ -32,9 +32,30 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
/// If a variable is a empty variable, that name will be used.
const std::string kEmptyVarName = "@EMPTY@";
/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be convert to a unique name in scope after OpCreator.
const std::string kTempVarName = "@TEMP@";
/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another varibale.
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
const std::string kGradVarSuffix = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros.
const std::string kZeroVarSuffix = "@ZERO";
inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix;
}
class OperatorBase; class OperatorBase;
class InferShapeContext; class InferShapeContext;
class ExecutionContext; class ExecutionContext;
/** /**
* OperatorBase has the basic element that Net will call to do computation. * OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User * Only CreateOperator from OpRegistry will new Operator directly. User
...@@ -43,25 +64,6 @@ class ExecutionContext; ...@@ -43,25 +64,6 @@ class ExecutionContext;
*/ */
class OperatorBase { class OperatorBase {
public: public:
/// If a variable is a empty variable, that name will be used.
static std::string EMPTY_VAR_NAME() { return "@EMPTY@"; }
/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be convert to a unique name in scope after OpCreator.
static std::string TMP_VAR_NAME() { return "@TEMP@"; }
/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another varibale.
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
static std::string GRAD_VAR_NAME(const std::string& name) {
return name + GRAD_VAR_SUFFIX();
}
/// Variables with this suffix are supposed to be filled up with zeros.
static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; }
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
template <typename T> template <typename T>
......
...@@ -154,8 +154,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -154,8 +154,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def_submodule( m.def_submodule(
"var_names", "var_names",
"The module will return special predefined variable name in Paddle") "The module will return special predefined variable name in Paddle")
.def("empty", OperatorBase::EMPTY_VAR_NAME) .def("empty", kEmptyVarName)
.def("temp", OperatorBase::TMP_VAR_NAME); .def("temp", kTempVarName);
// clang-format off // clang-format off
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext") py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static("create", .def_static("create",
......
...@@ -27,7 +27,7 @@ public: ...@@ -27,7 +27,7 @@ public:
{Output("before_act")}, {Output("before_act")},
{})); {}));
auto b = Input("b"); auto b = Input("b");
if (b != EMPTY_VAR_NAME()) { if (b != framework::kEmptyVarName) {
AddOp(OpRegistry::CreateOp("rowwise_add", AddOp(OpRegistry::CreateOp("rowwise_add",
{Output("before_act"), Input("b")}, {Output("before_act"), Input("b")},
{Output("before_act")}, {Output("before_act")},
......
...@@ -41,7 +41,7 @@ public: ...@@ -41,7 +41,7 @@ public:
class MeanGradOp : public OperatorWithKernel { class MeanGradOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + GRAD_VAR_SUFFIX()) ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
->Resize(ctx.Input<Tensor>("X")->dims()); ->Resize(ctx.Input<Tensor>("X")->dims());
} }
}; };
......
...@@ -39,10 +39,10 @@ template <typename Place, typename T> ...@@ -39,10 +39,10 @@ template <typename Place, typename T>
class MeanGradKernel : public OpKernel { class MeanGradKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto OG = context.Input<Tensor>("Out" + OperatorBase::GRAD_VAR_SUFFIX()); auto OG = context.Input<Tensor>("Out" + framework::kGradVarSuffix);
PADDLE_ENFORCE(framework::product(OG->dims()) == 1, PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
"Mean Gradient should be scalar"); "Mean Gradient should be scalar");
auto IG = context.Output<Tensor>("X" + OperatorBase::GRAD_VAR_SUFFIX()); auto IG = context.Output<Tensor>("X" + framework::kGradVarSuffix);
IG->mutable_data<T>(context.GetPlace()); IG->mutable_data<T>(context.GetPlace());
T ig_size = (T)framework::product(IG->dims()); T ig_size = (T)framework::product(IG->dims());
......
...@@ -48,12 +48,12 @@ protected: ...@@ -48,12 +48,12 @@ protected:
PADDLE_ENFORCE(ctx.OutputSize() == 1UL, PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
"Output of SoftmaxOpGrad should be 1"); "Output of SoftmaxOpGrad should be 1");
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null"); PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
PADDLE_ENFORCE(ctx.InputVar(GRAD_VAR_NAME("Y")) != nullptr, PADDLE_ENFORCE(ctx.InputVar(framework::GradVarName("Y")) != nullptr,
"Input(Y@GRAD) should not be null"); "Input(Y@GRAD) should not be null");
PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() == PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
ctx.Input<Tensor>(GRAD_VAR_NAME("Y"))->dims(), ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
"the shape of Input(0) and Input(1) should be the same"); "the shape of Input(0) and Input(1) should be the same");
ctx.Output<Tensor>(GRAD_VAR_NAME("X")) ctx.Output<Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("Y")->dims()); ->Resize(ctx.Input<Tensor>("Y")->dims());
} }
}; };
......
...@@ -68,8 +68,8 @@ public: ...@@ -68,8 +68,8 @@ public:
std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>(); std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>();
auto Y = context.Input<Tensor>("Y"); auto Y = context.Input<Tensor>("Y");
auto dY = context.Input<Tensor>(OperatorBase::GRAD_VAR_NAME("Y")); auto dY = context.Input<Tensor>(framework::GradVarName("Y"));
auto dX = context.Output<Tensor>(OperatorBase::GRAD_VAR_NAME("X")); auto dX = context.Output<Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
const int batch_size = Y->dims()[0]; const int batch_size = Y->dims()[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册