提交 e53dc8a2 编写于 作者: W whs 提交者: GitHub

Merge pull request #3937 from wanghaoshuang/clip_op

Add clip op
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/clip_op.h"
namespace paddle {
namespace operators {
using framework::LoDTensor;
class ClipOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of ClipOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of ClipOp should not be null.");
auto x_dims = ctx.Input<LoDTensor>("X")->dims();
auto max = Attr<float>("max");
auto min = Attr<float>("min");
PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
ctx.Output<LoDTensor>("Out")->Resize(x_dims);
}
};
template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor)The input of clip op."
"The input should be a k-D tensor(k > 0 and k < 7)");
AddOutput("Out", "(Tensor)The output of clip op with shape as input(X)");
AddAttr<AttrType>(
"min", "(float)Minimum value, under which element is replaced by min.");
AddAttr<AttrType>(
"max", "(float)Maximum value, above which element is replaced by max");
AddComment(R"DOC(
Clip operator limits the given input within an interval. The interval is
specified with arguments 'min' and 'max'.
)DOC");
}
};
class ClipOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<LoDTensor>("X")->dims();
auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (x_grad != nullptr) {
x_grad->Resize(x_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
ops::ClipOpGrad);
REGISTER_OP_CPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/clip_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/transform.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using platform::Transform;
template <typename T>
class ClipFunctor {
public:
explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x) const {
if (x < min_)
return min_;
else if (x > max_)
return max_;
else
return x;
}
private:
T min_;
T max_;
};
template <typename T>
class ClipGradFunctor {
public:
explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x, const T& y) const {
return (y > min_ && y < max_) ? x : 0;
}
private:
T min_;
T max_;
};
template <typename Place, typename T>
class ClipKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min");
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int64_t numel = x->numel();
Transform<Place> trans;
trans(context.device_context(), x_data, x_data + numel, out_data,
ClipFunctor<T>(min, max));
}
};
template <typename Place, typename T>
class ClipGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min");
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
auto* x = context.Input<Tensor>("X");
int64_t numel = d_out->numel();
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
const T* d_out_data = d_out->data<T>();
const T* x_data = x->data<T>();
Transform<Place> trans;
trans(context.device_context(), d_out_data, d_out_data + numel, x_data,
d_x_data, ClipGradFunctor<T>(min, max));
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
class TestClipOp(OpTest):
def setUp(self):
self.max_relative_error = 0.006
self.initTestCase()
input = np.random.random(self.shape).astype("float32")
input[np.abs(input - self.min) < self.max_relative_error] = 0.5
input[np.abs(input - self.max) < self.max_relative_error] = 0.5
self.op_type = "clip"
self.inputs = {'X': input, }
self.attrs = {}
self.attrs['min'] = self.min
self.attrs['max'] = self.max
self.outputs = {
'Out': np.clip(self.inputs['X'], self.attrs['min'],
self.attrs['max'])
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X'], 'Out', max_relative_error=self.max_relative_error)
def initTestCase(self):
self.shape = (4, 4)
self.max = 0.7
self.min = 0.1
class TestCase1(TestClipOp):
def initTestCase(self):
self.shape = (8, 16, 8)
self.max = 0.7
self.min = 0
class TestCase2(TestClipOp):
def initTestCase(self):
self.shape = (8, 16)
self.max = 1
self.min = 0
class TestCase3(TestClipOp):
def initTestCase(self):
self.shape = (4, 8, 16)
self.max = 0.7
self.min = 0.2
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册