提交 6e964ad5 编写于 作者: W wanghaoshuang

Fix issues

上级 320df7ad
......@@ -26,8 +26,8 @@ class ClipOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto max = GetAttr<float>("max");
auto min = GetAttr<float>("min");
auto max = Attr<float>("max");
auto min = Attr<float>("min");
PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
ctx.Output<Tensor>("Out")->Resize(x_dims);
}
......
......@@ -34,8 +34,8 @@ template <typename T>
class ClipGradientOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.op().GetAttr<float>("max");
auto min = context.op().GetAttr<float>("min");
auto max = context.op().Attr<float>("max");
auto min = context.op().Attr<float>("min");
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* x = context.Output<Tensor>("X");
......
......@@ -30,8 +30,8 @@ template <typename Place, typename T>
class ClipKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.op().GetAttr<float>("max");
auto min = context.op().GetAttr<float>("min");
auto max = context.op().Attr<float>("max");
auto min = context.op().Attr<float>("min");
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
......@@ -46,8 +46,8 @@ template <typename T>
class ClipGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.op().GetAttr<float>("max");
auto min = context.op().GetAttr<float>("min");
auto max = context.op().Attr<float>("max");
auto min = context.op().Attr<float>("min");
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* x = context.Output<Tensor>("X");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册