/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/operators/activation_op.h" #include "paddle/operators/math/math_function.h" #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; template using EigenMatrix = framework::EigenMatrix; enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 }; template class GRUUnitKernel : public framework::OpKernel { public: template void ActCompute(const int act_type, const Device& d, X x, Y y) const { if (act_type == identity) y.device(d) = x; else if (act_type == sigmoid) SigmoidFunctor()(d, x, y); else if (act_type == tanh) TanhFunctor()(d, x, y); else if (act_type == relu) ReluFunctor()(d, x, y); else PADDLE_THROW("unsupported activation type"); } void Compute(const framework::ExecutionContext& context) const override { auto* input = context.Input("Input"); auto* hidden_prev = context.Input("HiddenPrev"); auto* weight = context.Input("Weight"); auto* bias = context.Input("Bias"); auto* gate = context.Output("Gate"); gate->mutable_data(context.GetPlace()); auto* reset_hidden_prev = context.Output("ResetHiddenPrev"); reset_hidden_prev->mutable_data(context.GetPlace()); auto* hidden = context.Output("Hidden"); hidden->mutable_data(context.GetPlace()); int batch_size = input->dims()[0]; int frame_size = hidden_prev->dims()[1]; auto x = EigenMatrix::From(*input); auto h_p = EigenMatrix::From(*hidden_prev); auto g = EigenMatrix::From(*gate); auto r_h_p = EigenMatrix::From(*reset_hidden_prev); auto h = EigenMatrix::From(*hidden); auto place = context.GetEigenDevice(); // calculate unactivated gate outputs if (bias) { auto b = EigenMatrix::From(*bias); g.device(place) = x + b.reshape(Eigen::array({{1, frame_size * 3}})) .broadcast(Eigen::array({{batch_size, 1}})); } else { g.device(place) = x; } const T* hidden_prev_data = hidden_prev->data(); const T* weight_data = weight->data(); T* gate_data = gate->data(); T* reset_hidden_prev_data = reset_hidden_prev->data(); math::gemm(context.device_context(), false, false, batch_size, 2 * frame_size, frame_size, 1, hidden_prev_data, frame_size, weight_data, frame_size * 2, 1, gate_data, frame_size * 3); // calculate activited gate Eigen::array extents({{batch_size, frame_size}}); Eigen::array u_offsets({{0, 0}}); ActCompute(context.Attr("gate_activation"), place, g.slice(u_offsets, extents), g.slice(u_offsets, extents)); auto u = g.slice(u_offsets, extents); // update gate Eigen::array r_offsets({{0, frame_size}}); ActCompute(context.Attr("gate_activation"), place, g.slice(r_offsets, extents), g.slice(r_offsets, extents)); auto r = g.slice(r_offsets, extents); // reset gate r_h_p.device(place) = r * h_p; // reset previous hidden state math::gemm(context.device_context(), false, false, batch_size, frame_size, frame_size, 1, reset_hidden_prev_data, frame_size, weight_data + frame_size * frame_size * 2, frame_size, 1, gate_data + frame_size * 2, frame_size * 3); Eigen::array c_offsets({{0, frame_size * 2}}); ActCompute(context.Attr("activation"), place, g.slice(c_offsets, extents), g.slice(c_offsets, extents)); auto c = g.slice(c_offsets, extents); // output candidate // calculate final output h.device(place) = u * (h_p - c) + c; } }; template class GRUUnitGradKernel : public framework::OpKernel { public: template void ActGradCompute(const int act_type, const Device& d, X x, Y y, DX dx, DY dy) const { // x is dummy and won't be used even in Relu(use y instead) if (act_type == identity) dx.device(d) = dy; else if (act_type == sigmoid) SigmoidGradFunctor()(d, x, y, dy, dx); else if (act_type == tanh) TanhGradFunctor()(d, x, y, dy, dx); else if (act_type == relu) ReluGradFunctor()(d, x, y, dy, dx); else PADDLE_THROW("unsupported activation type"); } void Compute(const framework::ExecutionContext& context) const override { auto* input = context.Input("Input"); auto* hidden_prev = context.Input("HiddenPrev"); auto* weight = context.Input("Weight"); auto* gate = context.Input("Gate"); auto* reset_hidden_prev = context.Input("ResetHiddenPrev"); auto* hidden_grad = context.Input(framework::GradVarName("Hidden")); auto* input_grad = context.Output(framework::GradVarName("Input")); auto* hidden_prev_grad = context.Output(framework::GradVarName("HiddenPrev")); auto* weight_grad = context.Output(framework::GradVarName("Weight")); auto* bias_grad = context.Output(framework::GradVarName("Bias")); input_grad->mutable_data(context.GetPlace()); hidden_prev_grad->mutable_data(context.GetPlace()); weight_grad->mutable_data(context.GetPlace()); Tensor gate_grad; gate_grad.mutable_data(input->dims(), context.GetPlace()); Tensor reset_hidden_prev_grad; reset_hidden_prev_grad.mutable_data(reset_hidden_prev->dims(), context.GetPlace()); int batch_size = input->dims()[0]; int frame_size = hidden_prev->dims()[1]; const T* hidden_prev_data = hidden_prev->data(); T* hidden_prev_grad_data = hidden_prev_grad->data(); const T* weight_data = weight->data(); T* weight_grad_data = weight_grad->data(); T* gate_grad_data = gate_grad.data(); const T* reset_hidden_prev_data = reset_hidden_prev->data(); T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data(); auto h_p = EigenMatrix::From(*hidden_prev); auto g = EigenMatrix::From(*gate); auto d_h = EigenMatrix::From(*hidden_grad); auto d_x = EigenMatrix::From(*input_grad); auto d_h_p = EigenMatrix::From(*hidden_prev_grad); auto d_g = EigenMatrix::From(gate_grad); auto d_r_h_p = EigenMatrix::From(reset_hidden_prev_grad); auto place = context.GetEigenDevice(); Eigen::array extents({{batch_size, frame_size}}); Eigen::array u_offsets({{0, 0}}); auto u = g.slice(u_offsets, extents); // update gate Eigen::array r_offsets({{0, frame_size}}); auto r = g.slice(r_offsets, extents); // reset gate Eigen::array c_offsets({{0, frame_size * 2}}); auto c = g.slice(c_offsets, extents); // output candidate // backward for unactivated update gate ActGradCompute(context.Attr("gate_activation"), place, u, u, d_g.slice(u_offsets, extents), d_h * (h_p - c)); // backward for unactivated output candidate ActGradCompute(context.Attr("activation"), place, c, c, d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u)); // backward for reset_hidden_prev math::gemm(context.device_context(), false, true, batch_size, frame_size, frame_size, 1, gate_grad_data + frame_size * 2, frame_size * 3, weight_data + frame_size * frame_size * 2, frame_size, 0, reset_hidden_prev_grad_data, frame_size); // backward for state_weight math::gemm( context.device_context(), true, false, frame_size, frame_size, batch_size, 1, reset_hidden_prev_data, frame_size, gate_grad_data + frame_size * 2, frame_size * 3, 0, weight_grad_data + frame_size * frame_size * 2, frame_size); // backward for unactivated reset gate ActGradCompute(context.Attr("gate_activation"), place, r, r, d_g.slice(r_offsets, extents), d_r_h_p * h_p); // backward for update_gate_weight and reset_gate_weight math::gemm(context.device_context(), true, false, frame_size, frame_size * 2, batch_size, 1, hidden_prev_data, frame_size, gate_grad_data, frame_size * 3, 0, weight_grad_data, frame_size * 2); // backward for hidden_prev d_h_p.device(place) = d_r_h_p * r + d_h * u; math::gemm(context.device_context(), false, true, batch_size, frame_size, frame_size * 2, 1, gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1, hidden_prev_grad_data, frame_size); // backward for input d_x.device(place) = d_g; // backward for bias if (bias_grad) { bias_grad->mutable_data(context.GetPlace()); auto d_b = EigenMatrix::From(*bias_grad); d_b.device(place) = d_g.sum(Eigen::array({{0}})); } } }; } // namespace operators } // namespace paddle