提交 a7d5b1ab 编写于 作者: Y Yi Wang 提交者: GitHub

Merge pull request #3217 from wangkuiyi/const

Move constants from framework::OperatorBase to framework::
......@@ -59,19 +59,17 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take
// too much time for calculation, but it is useful for simplifying logic.
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) {
if (AllInSet(forwardOp.inputs_, kGradVarSuffix, no_grad_names)) {
return NOP();
}
// All output gradients of forwarding operator do not need to calculate.
// Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) {
if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) {
for (auto& name : forwardOp.inputs_) {
// Mark all input is not need
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
no_grad_names.insert(name + kGradVarSuffix);
}
return NOP();
}
......@@ -134,9 +132,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
for (std::string& grad_input : grad_op->inputs_) {
if (no_grad_names.count(grad_input)) {
std::string prefix = grad_input.substr(
0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();
std::string prefix =
grad_input.substr(0, grad_input.size() - kGradVarSuffix.size());
grad_input = prefix + kZeroVarSuffix;
// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
......@@ -147,7 +145,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
for (std::string& grad_output : grad_op->outputs_) {
if (no_grad_names.count(grad_output)) {
grad_output = OperatorBase::EMPTY_VAR_NAME();
grad_output = kEmptyVarName;
}
}
......@@ -168,14 +166,14 @@ std::shared_ptr<OperatorBase> Backward(
std::unordered_set<std::string> no_grad_names;
no_grad_names.reserve(no_grad_vars.size());
no_grad_names.insert(OperatorBase::EMPTY_VAR_NAME() +
OperatorBase::GRAD_VAR_SUFFIX());
no_grad_names.insert(kEmptyVarName + kGradVarSuffix);
for (auto& name : no_grad_vars) {
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
no_grad_names.insert(name + kGradVarSuffix);
}
size_t uid = 0;
return BackwardRecursive(forwardOp, no_grad_names, uid);
}
} // namespace framework
} // namespace paddle
......@@ -78,14 +78,14 @@ class FcOp : public ops::NetOp {
{Output("mul_result")}, {}));
auto b_name = Input("b");
std::string before_act = "mul_result";
if (b_name != EMPTY_VAR_NAME()) {
if (b_name != kEmptyVarName) {
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name},
{Output("add_result")}, {}));
before_act = "add_result";
} else {
auto out_varname = Output("add_result");
if (out_varname != EMPTY_VAR_NAME()) {
this->Rename(out_varname, EMPTY_VAR_NAME());
if (out_varname != kEmptyVarName) {
this->Rename(out_varname, kEmptyVarName);
}
}
......@@ -163,13 +163,12 @@ TEST(Backward, simple_op_grad) {
ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ(4UL, gop->inputs_.size());
ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), gop->inputs_[0]);
ASSERT_EQ(f::kEmptyVarName, gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_);
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]);
ASSERT_EQ("X" + f::kGradVarSuffix, gop->outputs_[0]);
ASSERT_EQ("b" + f::kGradVarSuffix, gop->outputs_[1]);
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ("X" + f::kGradVarSuffix, gop->Output("X" + f::kGradVarSuffix));
}
TEST(Backward, simple_op_not_need_grad) {
......@@ -177,7 +176,7 @@ TEST(Backward, simple_op_not_need_grad) {
ASSERT_NE(fwd, nullptr);
auto gop = f::Backward(*fwd, {"X"});
ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
"X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"X" + f::kGradVarSuffix),
gop->outputs_.end());
auto no_input_gop = f::Backward(*fwd, {"X", "b"});
......@@ -210,9 +209,9 @@ TEST(Backward, net_fc_backward_normal) {
}
TEST(Backward, net_fc_backward_not_have_b) {
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
"fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()},
{"mul_result", "add_result", "tmp"}, {});
std::shared_ptr<f::OperatorBase> fwd =
f::OpRegistry::CreateOp("fc", {"X", "w", f::kEmptyVarName},
{"mul_result", "add_result", "tmp"}, {});
ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp());
......@@ -242,24 +241,21 @@ TEST(Backward, net_input_of_network_not_need_grad) {
std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
bwd_net->outputs_.begin(), bwd_net->outputs_.end());
all_output.erase(f::OperatorBase::EMPTY_VAR_NAME());
all_output.erase(f::kEmptyVarName);
for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end());
ASSERT_NE(all_output.find(out + f::kGradVarSuffix), all_output.end());
}
// Not Generated X
ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end());
ASSERT_EQ(all_output.find("X" + f::kGradVarSuffix), all_output.end());
ASSERT_EQ(2UL, bwd_net->ops_.size());
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
ASSERT_EQ(3UL, first_fc_grad->ops_.size());
ASSERT_EQ(
f::OperatorBase::EMPTY_VAR_NAME(),
first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ(f::kEmptyVarName,
first_fc_grad->ops_[2]->Output("A" + f::kGradVarSuffix));
}
TEST(Backward, net_shared_weight) {
......@@ -311,17 +307,15 @@ TEST(Backward, op_part_of_output_are_not_need) {
ASSERT_EQ(1UL, fill_zero.inputs_.size());
ASSERT_EQ("Z", fill_zero.inputs_[0]);
ASSERT_EQ(1UL, fill_zero.outputs_.size());
ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(), fill_zero.outputs_[0]);
ASSERT_EQ("Z" + f::kZeroVarSuffix, fill_zero.outputs_[0]);
auto &d_many_out = *net->ops_[1];
ASSERT_EQ("many_output_op_grad", d_many_out.type_);
ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG
ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(),
d_many_out.Input("z" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ("Y" + f::OperatorBase::GRAD_VAR_SUFFIX(),
d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ("Z" + f::kZeroVarSuffix, d_many_out.Input("z" + f::kGradVarSuffix));
ASSERT_EQ("Y" + f::kGradVarSuffix, d_many_out.Input("y" + f::kGradVarSuffix));
ASSERT_EQ("X" + f::kGradVarSuffix,
d_many_out.Output("x" + f::kGradVarSuffix));
}
TEST(Backward, op_part_of_input_are_not_need) {
......@@ -331,12 +325,10 @@ TEST(Backward, op_part_of_input_are_not_need) {
ASSERT_EQ(grad_mul.type_, "mul_grad");
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()),
f::OperatorBase::EMPTY_VAR_NAME());
ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"b" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_mul.Output("A" + f::kGradVarSuffix), f::kEmptyVarName);
ASSERT_EQ(grad_mul.Output("B" + f::kGradVarSuffix), "b" + f::kGradVarSuffix);
ASSERT_EQ(grad_mul.Input("Out" + f::kGradVarSuffix),
"out" + f::kGradVarSuffix);
ASSERT_EQ(grad_mul.Input("A"), "a");
ASSERT_EQ(grad_mul.Input("B"), "b");
ASSERT_EQ(grad_mul.Input("Out"), "out");
......@@ -368,23 +360,4 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
/*
EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
f::OperatorBase::EMPTY_VAR_NAME());
EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"w3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"b3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Output("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Input("X"), "out2");
EXPECT_EQ(grad_fc.Input("W"), "w3");
EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3");
EXPECT_EQ(grad_fc.Input("add_result"), "tmp_out3");
EXPECT_EQ(grad_fc.Input("Out"), "out3");
*/
}
......@@ -56,8 +56,7 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
for (const auto& arg : src_arg_list) {
std::string src_name = arg.name();
std::string dst_name =
is_grad ? src_name + OperatorBase::GRAD_VAR_SUFFIX() : src_name;
std::string dst_name = is_grad ? src_name + kGradVarSuffix : src_name;
(*dst_op->in_out_idxs_)[dst_name] = idx++;
int src_arg_idx = src_op->in_out_idxs_->at(src_name);
int src_begin =
......@@ -65,10 +64,9 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
int src_end = src_format == nullptr ? src_arg_idx + 1
: src_format->at(src_arg_idx + 1);
for (int i = src_begin; i < src_end; ++i) {
std::string s = is_grad ? src_inout[i] + OperatorBase::GRAD_VAR_SUFFIX()
: arg.ignore_gradient()
? OperatorBase::EMPTY_VAR_NAME()
: src_inout[i];
std::string s =
is_grad ? src_inout[i] + kGradVarSuffix
: (arg.ignore_gradient() ? kEmptyVarName : src_inout[i]);
dst_inout.emplace_back(s);
}
if (dst_format != nullptr) {
......
......@@ -83,24 +83,21 @@ TEST(GradOpBuilder, MutiInOut) {
EXPECT_EQ(grad_test_op->Input("Out1"), "out1");
EXPECT_EQ(grad_test_op->Inputs("Out2_mult"),
std::vector<std::string>({"out2_1", "out2_2"}));
EXPECT_EQ(grad_test_op->Input("Out1" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out1" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(
grad_test_op->Inputs("Out2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
std::vector<std::string>(
{"out2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"out2_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
EXPECT_EQ(grad_test_op->Input("Out1" + f::kGradVarSuffix),
"out1" + f::kGradVarSuffix);
EXPECT_EQ(grad_test_op->Inputs("Out2_mult" + f::kGradVarSuffix),
std::vector<std::string>(
{"out2_1" + f::kGradVarSuffix, "out2_2" + f::kGradVarSuffix}));
ASSERT_EQ(grad_test_op->outputs_.size(), 5UL);
EXPECT_EQ(grad_test_op->Output("In1" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"in1" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(
grad_test_op->Outputs("In2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
std::vector<std::string>({"in2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"in2_2" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"in2_3" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
EXPECT_EQ(grad_test_op->Output("In3" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"in3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_test_op->Output("In1" + f::kGradVarSuffix),
"in1" + f::kGradVarSuffix);
EXPECT_EQ(grad_test_op->Outputs("In2_mult" + f::kGradVarSuffix),
std::vector<std::string>({"in2_1" + f::kGradVarSuffix,
"in2_2" + f::kGradVarSuffix,
"in2_3" + f::kGradVarSuffix}));
EXPECT_EQ(grad_test_op->Output("In3" + f::kGradVarSuffix),
"in3" + f::kGradVarSuffix);
}
TEST(GradOpBuilder, IOIgnoredInGradient) {
......@@ -116,30 +113,25 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL);
EXPECT_EQ(grad_test_op->Input("In1"), "in1");
EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
std::vector<std::string>({f::OperatorBase::EMPTY_VAR_NAME(),
f::OperatorBase::EMPTY_VAR_NAME()}));
std::vector<std::string>({f::kEmptyVarName, f::kEmptyVarName}));
EXPECT_EQ(grad_test_op->Inputs("In3_mult"),
std::vector<std::string>({"in3_1", "in3_2"}));
EXPECT_EQ(grad_test_op->Inputs("Out1_mult"),
std::vector<std::string>({"out1_1", "out1_2"}));
EXPECT_EQ(grad_test_op->Input("Out2"), f::OperatorBase::EMPTY_VAR_NAME());
EXPECT_EQ(
grad_test_op->Inputs("Out1_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
std::vector<std::string>(
{"out1_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"out1_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
EXPECT_EQ(grad_test_op->Input("Out2" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out2" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_test_op->Input("Out2"), f::kEmptyVarName);
EXPECT_EQ(grad_test_op->Inputs("Out1_mult" + f::kGradVarSuffix),
std::vector<std::string>(
{"out1_1" + f::kGradVarSuffix, "out1_2" + f::kGradVarSuffix}));
EXPECT_EQ(grad_test_op->Input("Out2" + f::kGradVarSuffix),
"out2" + f::kGradVarSuffix);
ASSERT_EQ(grad_test_op->outputs_.size(), 5UL);
EXPECT_EQ(grad_test_op->Output("In1" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"in1" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(
grad_test_op->Outputs("In2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
std::vector<std::string>({"in2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"in2_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
EXPECT_EQ(
grad_test_op->Outputs("In3_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
std::vector<std::string>({"in3_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
"in3_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
EXPECT_EQ(grad_test_op->Output("In1" + f::kGradVarSuffix),
"in1" + f::kGradVarSuffix);
EXPECT_EQ(grad_test_op->Outputs("In2_mult" + f::kGradVarSuffix),
std::vector<std::string>(
{"in2_1" + f::kGradVarSuffix, "in2_2" + f::kGradVarSuffix}));
EXPECT_EQ(grad_test_op->Outputs("In3_mult" + f::kGradVarSuffix),
std::vector<std::string>(
{"in3_1" + f::kGradVarSuffix, "in3_2" + f::kGradVarSuffix}));
}
......@@ -341,7 +341,7 @@ class OpRegistry {
static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) {
if (outname == OperatorBase::TMP_VAR_NAME()) {
if (outname == kTempVarName) {
outname += op->type_;
outname += "@";
outname += std::to_string(gUniqId.fetch_add(1));
......
......@@ -32,9 +32,29 @@ limitations under the License. */
namespace paddle {
namespace framework {
/// If a variable is a empty variable, that name will be used.
const std::string kEmptyVarName = "@EMPTY@";
/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be convert to a unique name in scope after OpCreator.
const std::string kTempVarName = "@TEMP@";
/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another varibale.
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
const std::string kGradVarSuffix = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros.
const std::string kZeroVarSuffix = "@ZERO";
inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix;
}
class OperatorBase;
class InferShapeContext;
class ExecutionContext;
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
......@@ -43,25 +63,6 @@ class ExecutionContext;
*/
class OperatorBase {
public:
/// If a variable is a empty variable, that name will be used.
static std::string EMPTY_VAR_NAME() { return "@EMPTY@"; }
/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be convert to a unique name in scope after OpCreator.
static std::string TMP_VAR_NAME() { return "@TEMP@"; }
/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another varibale.
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
static std::string GRAD_VAR_NAME(const std::string& name) {
return name + GRAD_VAR_SUFFIX();
}
/// Variables with this suffix are supposed to be filled up with zeros.
static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; }
virtual ~OperatorBase() {}
template <typename T>
......
......@@ -163,8 +163,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def_submodule(
"var_names",
"The module will return special predefined variable name in Paddle")
.def("empty", OperatorBase::EMPTY_VAR_NAME)
.def("temp", OperatorBase::TMP_VAR_NAME);
.def("empty", []() { return kEmptyVarName; })
.def("temp", []() { return kTempVarName; });
// clang-format off
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static("create",
......
......@@ -27,7 +27,7 @@ public:
{Output("before_act")},
{}));
auto b = Input("b");
if (b != EMPTY_VAR_NAME()) {
if (b != framework::kEmptyVarName) {
AddOp(OpRegistry::CreateOp("rowwise_add",
{Output("before_act"), Input("b")},
{Output("before_act")},
......
......@@ -41,7 +41,7 @@ public:
class MeanGradOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + GRAD_VAR_SUFFIX())
ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
->Resize(ctx.Input<Tensor>("X")->dims());
}
};
......
......@@ -39,10 +39,10 @@ template <typename Place, typename T>
class MeanGradKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
auto OG = context.Input<Tensor>("Out" + OperatorBase::GRAD_VAR_SUFFIX());
auto OG = context.Input<Tensor>("Out" + framework::kGradVarSuffix);
PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
"Mean Gradient should be scalar");
auto IG = context.Output<Tensor>("X" + OperatorBase::GRAD_VAR_SUFFIX());
auto IG = context.Output<Tensor>("X" + framework::kGradVarSuffix);
IG->mutable_data<T>(context.GetPlace());
T ig_size = (T)framework::product(IG->dims());
......
......@@ -48,12 +48,12 @@ protected:
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
"Output of SoftmaxOpGrad should be 1");
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
PADDLE_ENFORCE(ctx.InputVar(GRAD_VAR_NAME("Y")) != nullptr,
PADDLE_ENFORCE(ctx.InputVar(framework::GradVarName("Y")) != nullptr,
"Input(Y@GRAD) should not be null");
PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
ctx.Input<Tensor>(GRAD_VAR_NAME("Y"))->dims(),
ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
"the shape of Input(0) and Input(1) should be the same");
ctx.Output<Tensor>(GRAD_VAR_NAME("X"))
ctx.Output<Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("Y")->dims());
}
};
......
......@@ -68,8 +68,8 @@ public:
std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>();
auto Y = context.Input<Tensor>("Y");
auto dY = context.Input<Tensor>(OperatorBase::GRAD_VAR_NAME("Y"));
auto dX = context.Output<Tensor>(OperatorBase::GRAD_VAR_NAME("X"));
auto dY = context.Input<Tensor>(framework::GradVarName("Y"));
auto dX = context.Output<Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
const int batch_size = Y->dims()[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册