未验证 提交 31fee180 编写于 作者: S smallv0221 提交者: GitHub

Update lstm_unit_op.h

上级 1e661715
......@@ -39,8 +39,8 @@ template <typename DeviceContext, typename T>
class LstmUnitKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(
platform::is_cpu_place(ctx.GetPlace()),
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet("It must use CPUPlace."));
auto* x_tensor = ctx.Input<framework::Tensor>("X");
......@@ -83,8 +83,8 @@ template <typename DeviceContext, typename T>
class LstmUnitGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(
platform::is_cpu_place(ctx.GetPlace()),
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet("It must use CPUPlace."));
auto x_tensor = ctx.Input<Tensor>("X");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册