Commit 6e964ad5 authored by wanghaoshuang

Fix issues

Parent: 320df7ad
@@ -26,8 +26,8 @@ class ClipOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
     auto x_dims = ctx.Input<Tensor>("X")->dims();
-    auto max = GetAttr<float>("max");
-    auto min = GetAttr<float>("min");
+    auto max = Attr<float>("max");
+    auto min = Attr<float>("min");
     PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
     ctx.Output<Tensor>("Out")->Resize(x_dims);
   }
...
@@ -34,8 +34,8 @@ template <typename T>
 class ClipGradientOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto max = context.op().GetAttr<float>("max");
-    auto min = context.op().GetAttr<float>("min");
+    auto max = context.op().Attr<float>("max");
+    auto min = context.op().Attr<float>("min");
     auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
     auto* x = context.Output<Tensor>("X");
...
@@ -30,8 +30,8 @@ template <typename Place, typename T>
 class ClipKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto max = context.op().GetAttr<float>("max");
-    auto min = context.op().GetAttr<float>("min");
+    auto max = context.op().Attr<float>("max");
+    auto min = context.op().Attr<float>("min");
     auto* x = context.Input<Tensor>("X");
     auto* out = context.Output<Tensor>("Out");
     out->mutable_data<T>(context.GetPlace());
@@ -46,8 +46,8 @@ template <typename T>
 class ClipGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto max = context.op().GetAttr<float>("max");
-    auto min = context.op().GetAttr<float>("min");
+    auto max = context.op().Attr<float>("max");
+    auto min = context.op().Attr<float>("min");
     auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
     auto* x = context.Output<Tensor>("X");
...
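Note: every hunk in this commit is the same mechanical rename of the attribute accessor, from `GetAttr<T>(name)` to `Attr<T>(name)`, both directly in `InferShape` and via `context.op()` in the CPU/CUDA kernels. A minimal self-contained sketch of the call-site pattern is below; `AttributeMap`, `OperatorStub`, and `main` are illustrative stand-ins, not the actual PaddlePaddle types.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <variant>

// Illustrative stand-in for a typed attribute map keyed by name.
using Attribute = std::variant<int, float, std::string>;
using AttributeMap = std::map<std::string, Attribute>;

// Minimal operator stub exposing the renamed accessor: Attr<T>(name)
// returns the attribute value cast to the requested type.
class OperatorStub {
 public:
  explicit OperatorStub(AttributeMap attrs) : attrs_(std::move(attrs)) {}

  template <typename T>
  T Attr(const std::string& name) const {
    return std::get<T>(attrs_.at(name));
  }

 private:
  AttributeMap attrs_;
};

int main() {
  OperatorStub op({{"min", 0.0f}, {"max", 10.0f}});
  // Call sites mirror the diff: Attr<float>("max") / Attr<float>("min").
  float max = op.Attr<float>("max");
  float min = op.Attr<float>("min");
  std::cout << "clip range: [" << min << ", " << max << "]\n";
  return 0;
}
```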