提交 a0fb6464 编写于 作者: K Kavya Srinet

Fixed one_hot_op.cu

上级 93d28850
......@@ -65,7 +65,8 @@ class OneHotCUDAKernel : public framework::OpKernel<T> {
int depth = context.Attr<int>("depth");
framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotOpCUDAFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>()));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册