Unverified · Commit b06aee7b · Author: Zhanghuihong Guan · Committer: GitHub

Transfer Python assertion in PReLU to C++ (#6261)

* removed the Python assertion for PReLU; the check needs to move into the C++ implementation

* added assertion in PReLU for parameter validation

* modifications based on review

* removed inheritance relation based on review

* formatted code

* moved a helper function to an anonymous namespace

* format code

Co-authored-by: Yao Chi <later@usopp.net>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent 8ff6c8af
@@ -61,11 +61,30 @@ class ReluGradFunctor : public BinaryFunctor {
   }
 };
 
-class PReluFunctor : public BinaryFunctor {
+namespace {
+Maybe<void> CheckPReLUParametersValid(const std::shared_ptr<Tensor>& x,
+                                      const std::shared_ptr<Tensor>& alpha) {
+  int num_params = alpha->dim(0);
+  CHECK_OR_RETURN(((num_params == 1) || (num_params == x->shape()->At(1))))
+      << "num_parameters in prelu must be 1 or " << x->shape()->At(1);
+  return Maybe<void>::Ok();
+}
+}  // namespace
+
+class PReluFunctor {
  public:
   PReluFunctor() {
     op_ = CHECK_JUST(one::OpBuilder("prelu").Input("x").Input("alpha").Output("y").Build());
   }
+
+  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x,
+                           const std::shared_ptr<Tensor>& alpha) const {
+    JUST(CheckPReLUParametersValid(x, alpha));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, alpha});
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
 };
 
 class PReluGradFunctor {
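For reference, the functor above is what `flow._C.prelu` is dispatched to, so the shape check now runs on every call of the functional API. A minimal usage sketch (the tensor shapes and the direct call to `flow._C.prelu` outside the module are illustrative assumptions, not part of this diff):

```python
import oneflow as flow

x = flow.randn(2, 4, 5, 5)           # NCHW input with 4 channels
alpha = flow.Tensor(4).fill_(0.25)   # one slope per channel, like the module's weight

# Dispatches to PReluFunctor; CheckPReLUParametersValid verifies that
# alpha->dim(0) is either 1 or x->shape()->At(1) (the channel dimension).
y = flow._C.prelu(x, alpha)
```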
@@ -92,9 +92,6 @@ class PReLU(Module):
         self.weight = flow.nn.Parameter(flow.Tensor(num_parameters).fill_(init))
 
     def forward(self, x):
-        assert (
-            self.num_parameters == 1 or self.num_parameters == x.shape[1]
-        ), f"num_parameters in prelu must be 1 or {x.shape[1]}"
         return flow._C.prelu(x, self.weight)
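To illustrate the same constraint at the module level (a usage sketch, not part of this change; the input shape is an arbitrary assumption):

```python
import oneflow as flow

x = flow.randn(2, 4, 5, 5)                 # 4-channel input

y1 = flow.nn.PReLU(num_parameters=1)(x)    # OK: one shared slope
y2 = flow.nn.PReLU(num_parameters=4)(x)    # OK: one slope per channel

# A mismatching value such as num_parameters=3 now fails inside the
# C++ functor (CHECK_OR_RETURN) rather than via a Python assert.
```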