/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/operators/math/math_function.h" #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; template using EigenMatrix = framework::EigenMatrix; template class GRUUnitKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* input = context.Input("input"); auto* hidden_prev = context.Input("hidden_prev"); auto* weight = context.Input("weight"); auto* bias = context.Input("bias"); auto* gate = context.Output("gate"); gate->mutable_data(context.GetPlace()); auto* reset_hidden_prev = context.Output("reset_hidden_prev"); reset_hidden_prev->mutable_data(context.GetPlace()); auto* hidden = context.Output("hidden"); hidden->mutable_data(context.GetPlace()); int batch_size = input->dims()[0]; int frame_size = hidden_prev->dims()[1]; auto x = EigenMatrix::From(*input); auto h_p = EigenMatrix::From(*hidden_prev); auto b = EigenMatrix::From(*bias); auto g = EigenMatrix::From(*gate); auto r_h_p = EigenMatrix::From(*reset_hidden_prev); auto h = EigenMatrix::From(*hidden); auto place = context.GetEigenDevice(); // calculate unactivated gate outputs g.device(place) = x + b.reshape(Eigen::array({{1, frame_size * 3}})) .broadcast(Eigen::array({{batch_size, 1}})); const T* hidden_prev_data = hidden_prev->data(); const T* weight_data = weight->data(); T* gate_data = gate->data(); T* reset_hidden_prev_data = reset_hidden_prev->data(); math::gemm(context.device_context(), false, false, batch_size, 2 * frame_size, frame_size, 1, hidden_prev_data, frame_size, weight_data, frame_size * 2, 1, gate_data, frame_size * 3); // calculate activited gate Eigen::array extents({{batch_size, frame_size}}); Eigen::array u_offsets({{0, 0}}); g.slice(u_offsets, extents).device(place) = g.slice(u_offsets, extents).sigmoid(); auto u = g.slice(u_offsets, extents); // update gate Eigen::array r_offsets({{0, frame_size}}); g.slice(r_offsets, extents).device(place) = g.slice(r_offsets, extents).sigmoid(); auto r = g.slice(r_offsets, extents); // reset gate r_h_p.device(place) = r * h_p; // reset previous hidden state math::gemm(context.device_context(), false, false, batch_size, frame_size, frame_size, 1, reset_hidden_prev_data, frame_size, weight_data + frame_size * frame_size * 2, frame_size, 1, gate_data + frame_size * 2, frame_size * 3); Eigen::array c_offsets({{0, frame_size * 2}}); g.slice(c_offsets, extents).device(place) = g.slice(c_offsets, extents).tanh(); auto c = g.slice(c_offsets, extents); // output candidate // calculate final output h.device(place) = u * (h_p - c) + c; } }; template class GRUUnitGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* input = context.Input("input"); auto* hidden_prev = context.Input("hidden_prev"); auto* weight = context.Input("weight"); auto* gate = context.Input("gate"); auto* reset_hidden_prev = context.Input("reset_hidden_prev"); auto* hidden_grad = context.Input(framework::GradVarName("hidden")); auto* input_grad = context.Output(framework::GradVarName("input")); auto* hidden_prev_grad = context.Output(framework::GradVarName("hidden_prev")); auto* weight_grad = context.Output(framework::GradVarName("weight")); auto* bias_grad = context.Output(framework::GradVarName("bias")); input_grad->mutable_data(context.GetPlace()); hidden_prev_grad->mutable_data(context.GetPlace()); weight_grad->mutable_data(context.GetPlace()); bias_grad->mutable_data(context.GetPlace()); Tensor gate_grad; gate_grad.mutable_data(input->dims(), context.GetPlace()); Tensor reset_hidden_prev_grad; reset_hidden_prev_grad.mutable_data(reset_hidden_prev->dims(), context.GetPlace()); int batch_size = input->dims()[0]; int frame_size = hidden_prev->dims()[1]; const T* hidden_prev_data = hidden_prev->data(); T* hidden_prev_grad_data = hidden_prev_grad->data(); const T* weight_data = weight->data(); T* weight_grad_data = weight_grad->data(); T* gate_grad_data = gate_grad.data(); const T* reset_hidden_prev_data = reset_hidden_prev->data(); T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data(); auto h_p = EigenMatrix::From(*hidden_prev); auto g = EigenMatrix::From(*gate); auto d_h = EigenMatrix::From(*hidden_grad); auto d_x = EigenMatrix::From(*input_grad); auto d_h_p = EigenMatrix::From(*hidden_prev_grad); auto d_b = EigenMatrix::From(*bias_grad); auto d_g = EigenMatrix::From(gate_grad); auto d_r_h_p = EigenMatrix::From(reset_hidden_prev_grad); auto place = context.GetEigenDevice(); Eigen::array extents({{batch_size, frame_size}}); Eigen::array u_offsets({{0, 0}}); auto u = g.slice(u_offsets, extents); // update gate Eigen::array r_offsets({{0, frame_size}}); auto r = g.slice(r_offsets, extents); // reset gate Eigen::array c_offsets({{0, frame_size * 2}}); auto c = g.slice(c_offsets, extents); // output candidate // backward for unactivated update gate d_g.slice(u_offsets, extents).device(place) = d_h * (h_p - c) * u * (u.constant(T(1)) - u); // backward for unactivated output candidate d_g.slice(c_offsets, extents).device(place) = d_h * (u.constant(T(1)) - u) * (c.constant(T(1)) - c * c); // backward for reset_hidden_prev math::gemm(context.device_context(), false, true, batch_size, frame_size, frame_size, 1, gate_grad_data + frame_size * 2, frame_size * 3, weight_data + frame_size * frame_size * 2, frame_size, 0, reset_hidden_prev_grad_data, frame_size); // backward for state_weight math::gemm( context.device_context(), true, false, frame_size, frame_size, batch_size, 1, reset_hidden_prev_data, frame_size, gate_grad_data + frame_size * 2, frame_size * 3, 0, weight_grad_data + frame_size * frame_size * 2, frame_size); // backward for unactivated reset gate d_g.slice(r_offsets, extents).device(place) = d_r_h_p * h_p * r * (r.constant(T(1)) - r); // backward for update_gate_weight and reset_gate_weight math::gemm(context.device_context(), true, false, frame_size, frame_size * 2, batch_size, 1, hidden_prev_data, frame_size, gate_grad_data, frame_size * 3, 0, weight_grad_data, frame_size * 2); // backward for hidden_prev d_h_p.device(place) = d_r_h_p * r + d_h * u; math::gemm(context.device_context(), false, true, batch_size, frame_size, frame_size * 2, 1, gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1, hidden_prev_grad_data, frame_size); // backward for input d_x.device(place) = d_g; // backward for bias d_b.device(place) = d_g.sum(Eigen::array({{0}})); } }; } // namespace operators } // namespace paddle