提交 87da1542 编写于 作者: Y Yu Yang

FIx sigmoid_xe_with_logits_op compile

上级 63469dae
......@@ -21,7 +21,7 @@ namespace operators {
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
template <typename Place, typename T>
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X");
......@@ -48,7 +48,7 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
// dX = sigmoid(X) - labels
template <typename Place, typename T>
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel {
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册