提交 54f0d260 编写于 作者: D dangqingqing

fix input size.

上级 bd01cea1
......@@ -38,7 +38,8 @@ public:
class SigmoidOpGrad : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1,
// need to check input size 2 or 3, (dY, Y) or (dY, Y, X)
PADDLE_ENFORCE(ctx.InputSize() == 2,
"Sigmoid Gradient Op only have one input");
PADDLE_ENFORCE(ctx.OutputSize() == 1,
"Sigmoid Gradient Op only have one output");
......
......@@ -27,6 +27,7 @@ public:
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
// The clipping is used in Paddle's raw implenmention
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
......@@ -37,7 +38,7 @@ template <typename Place, typename T>
class SigmoidGradKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
// TODO(qingqing) maybe a helper funciton is needed fo the name x@GRAD
// maybe a helper funciton is needed fo the name x@GRAD
auto y_t = context.Input<Tensor>("Y");
auto dy_t = context.Input<Tensor>("Y@GRAD");
auto dx_t = context.Output<Tensor>("X@GRAD");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册