提交 f6b518c9 编写于 作者: D dangqingqing

Fix elementwise_mul_op.cc

上级 cb284283
...@@ -31,7 +31,7 @@ class ElementWiseMulOp : public framework::OperatorWithKernel { ...@@ -31,7 +31,7 @@ class ElementWiseMulOp : public framework::OperatorWithKernel {
auto y_dim = ctx.Input<Tensor>("Y")->dims(); auto y_dim = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.") "Rank of first input must >= rank of second input.")
ctx.Output<framework::Tensor>("Out")->Resize(x_dim); ctx.Output<framework::LoDTensor>("Out")->Resize(x_dim);
} }
}; };
...@@ -80,8 +80,10 @@ class ElementWiseMulOpGrad : public framework::OperatorWithKernel { ...@@ -80,8 +80,10 @@ class ElementWiseMulOpGrad : public framework::OperatorWithKernel {
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims(); auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto *x_grad =
auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y")); ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *y_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.") "Rank of first input must >= rank of second input.")
......
...@@ -176,10 +176,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -176,10 +176,6 @@ All parameter, weight, gradient are variables in Paddle.
.def("set_int", .def("set_int",
[](Variable &var, int val) -> void { *var.GetMutable<int>() = val; }) [](Variable &var, int val) -> void { *var.GetMutable<int>() = val; })
.def("get_int", [](const Variable &var) -> int { return var.Get<int>(); }) .def("get_int", [](const Variable &var) -> int { return var.Get<int>(); })
// .def("get_tensor",
// [](Variable &self) -> Tensor * { return
// self.GetMutable<Tensor>(); },
// py::return_value_policy::reference)
.def("get_tensor", .def("get_tensor",
[](Variable &self) -> LoDTensor * { [](Variable &self) -> LoDTensor * {
return self.GetMutable<LoDTensor>(); return self.GetMutable<LoDTensor>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册