提交 54f0d260 编写于 作者: D dangqingqing

fix input size.

上级 bd01cea1
...@@ -38,7 +38,8 @@ public: ...@@ -38,7 +38,8 @@ public:
class SigmoidOpGrad : public OperatorWithKernel { class SigmoidOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, // need to check input size 2 or 3, (dY, Y) or (dY, Y, X)
PADDLE_ENFORCE(ctx.InputSize() == 2,
"Sigmoid Gradient Op only have one input"); "Sigmoid Gradient Op only have one input");
PADDLE_ENFORCE(ctx.OutputSize() == 1, PADDLE_ENFORCE(ctx.OutputSize() == 1,
"Sigmoid Gradient Op only have one output"); "Sigmoid Gradient Op only have one output");
......
...@@ -27,6 +27,7 @@ public: ...@@ -27,6 +27,7 @@ public:
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
// The clipping is used in Paddle's raw implenmention
EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) = *(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp()); 1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
...@@ -37,7 +38,7 @@ template <typename Place, typename T> ...@@ -37,7 +38,7 @@ template <typename Place, typename T>
class SigmoidGradKernel : public OpKernel { class SigmoidGradKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
// TODO(qingqing) maybe a helper funciton is needed fo the name x@GRAD // maybe a helper funciton is needed fo the name x@GRAD
auto y_t = context.Input<Tensor>("Y"); auto y_t = context.Input<Tensor>("Y");
auto dy_t = context.Input<Tensor>("Y@GRAD"); auto dy_t = context.Input<Tensor>("Y@GRAD");
auto dx_t = context.Output<Tensor>("X@GRAD"); auto dx_t = context.Output<Tensor>("X@GRAD");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册