/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/operators/activation_op.h" #include "paddle/operators/math/math_function.h" #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; template using EigenMatrix = framework::EigenMatrix; template using EigenVector = framework::EigenVector; enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 }; template class GRUUnitKernel : public framework::OpKernel { public: template void ActCompute(const int act_type, const Device& d, X x, Y y) const { if (act_type == identity) y.device(d) = x; else if (act_type == sigmoid) SigmoidFunctor()(d, x, y); else if (act_type == tanh) TanhFunctor()(d, x, y); else if (act_type == relu) ReluFunctor()(d, x, y); else PADDLE_THROW("unsupported activation type"); } void Compute(const framework::ExecutionContext& context) const override { auto* input = context.Input("Input"); auto* hidden_prev = context.Input("HiddenPrev"); auto* weight = context.Input("Weight"); auto* bias = context.Input("Bias"); auto* gate = context.Output("Gate"); gate->mutable_data(context.GetPlace()); auto* reset_hidden_prev = context.Output("ResetHiddenPrev"); reset_hidden_prev->mutable_data(context.GetPlace()); auto* hidden = context.Output("Hidden"); hidden->mutable_data(context.GetPlace()); int batch_size = input->dims()[0]; int frame_size = hidden_prev->dims()[1]; auto x = EigenMatrix::From(*input); auto h_p = EigenMatrix::From(*hidden_prev); auto g = EigenMatrix::From(*gate); auto r_h_p = EigenMatrix::From(*reset_hidden_prev); auto h = EigenMatrix::From(*hidden); auto& place = *context.template device_context().eigen_device(); // calculate unactivated gate outputs if (bias) { auto b = EigenMatrix::From(*bias); g.device(place) = x + b.reshape(Eigen::array({{1, frame_size * 3}})) .broadcast(Eigen::array({{batch_size, 1}})); } else { g.device(place) = x; } const T* hidden_prev_data = hidden_prev->data(); const T* weight_data = weight->data(); T* gate_data = gate->data(); T* reset_hidden_prev_data = reset_hidden_prev->data(); math::gemm( context.template device_context(), false, false, batch_size, 2 * frame_size, frame_size, 1, hidden_prev_data, frame_size, weight_data, frame_size * 2, 1, gate_data, frame_size * 3); // calculate activited gate Eigen::array extents({{batch_size, frame_size}}); Eigen::array u_offsets({{0, 0}}); ActCompute(context.Attr("gate_activation"), place, g.slice(u_offsets, extents), g.slice(u_offsets, extents)); auto u = g.slice(u_offsets, extents); // update gate Eigen::array r_offsets({{0, frame_size}}); ActCompute(context.Attr("gate_activation"), place, g.slice(r_offsets, extents), g.slice(r_offsets, extents)); auto r = g.slice(r_offsets, extents); // reset gate r_h_p.device(place) = r * h_p; // reset previous hidden state math::gemm( context.template device_context(), false, false, batch_size, frame_size, frame_size, 1, reset_hidden_prev_data, frame_size, weight_data + frame_size * frame_size * 2, frame_size, 1, gate_data + frame_size * 2, frame_size * 3); Eigen::array c_offsets({{0, frame_size * 2}}); ActCompute(context.Attr("activation"), place, g.slice(c_offsets, extents), g.slice(c_offsets, extents)); auto c = g.slice(c_offsets, extents); // output candidate // calculate final output h.device(place) = u * (c - h_p) + h_p; } }; template class GRUUnitGradKernel : public framework::OpKernel { public: template void ActGradCompute(const int act_type, const Device& d, X x, Y y, DX dx, DY dy) const { // x is dummy and won't be used even in Relu(use y instead) if (act_type == identity) dx.device(d) = dy; else if (act_type == sigmoid) SigmoidGradFunctor()(d, x, y, dy, dx); else if (act_type == tanh) TanhGradFunctor()(d, x, y, dy, dx); else if (act_type == relu) ReluGradFunctor()(d, x, y, dy, dx); else PADDLE_THROW("unsupported activation type"); } void Compute(const framework::ExecutionContext& context) const override { auto* input = context.Input("Input"); auto* hidden_prev = context.Input("HiddenPrev"); auto* weight = context.Input("Weight"); auto* gate = context.Input("Gate"); auto* reset_hidden_prev = context.Input("ResetHiddenPrev"); auto* hidden_grad = context.Input(framework::GradVarName("Hidden")); auto* input_grad = context.Output(framework::GradVarName("Input")); auto* hidden_prev_grad = context.Output(framework::GradVarName("HiddenPrev")); auto* weight_grad = context.Output(framework::GradVarName("Weight")); auto* bias_grad = context.Output(framework::GradVarName("Bias")); Tensor gate_grad; Tensor reset_hidden_prev_grad; const T* hidden_prev_data = hidden_prev->data(); const T* weight_data = weight->data(); T* gate_grad_data = gate_grad.mutable_data(input->dims(), context.GetPlace()); const T* reset_hidden_prev_data = reset_hidden_prev->data(); T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.mutable_data( reset_hidden_prev->dims(), context.GetPlace()); auto h_p = EigenMatrix::From(*hidden_prev); auto g = EigenMatrix::From(*gate); auto d_h = EigenMatrix::From(*hidden_grad); auto d_g = EigenMatrix::From(gate_grad); auto d_r_h_p = EigenMatrix::From(reset_hidden_prev_grad); auto& place = *context.template device_context().eigen_device(); int batch_size = input->dims()[0]; int frame_size = hidden_prev->dims()[1]; Eigen::array extents({{batch_size, frame_size}}); Eigen::array u_offsets({{0, 0}}); auto u = g.slice(u_offsets, extents); // update gate Eigen::array r_offsets({{0, frame_size}}); auto r = g.slice(r_offsets, extents); // reset gate Eigen::array c_offsets({{0, frame_size * 2}}); auto c = g.slice(c_offsets, extents); // output candidate // backward for unactivated update gate ActGradCompute(context.Attr("gate_activation"), place, u, u, d_g.slice(u_offsets, extents), d_h * (c - h_p)); // backward for unactivated output candidate ActGradCompute(context.Attr("activation"), place, c, c, d_g.slice(c_offsets, extents), d_h * u); // backward for reset_hidden_prev math::gemm( context.template device_context(), false, true, batch_size, frame_size, frame_size, 1, gate_grad_data + frame_size * 2, frame_size * 3, weight_data + frame_size * frame_size * 2, frame_size, 0, reset_hidden_prev_grad_data, frame_size); // backward for unactivated reset gate ActGradCompute(context.Attr("gate_activation"), place, r, r, d_g.slice(r_offsets, extents), d_r_h_p * h_p); // backward for weight if (weight_grad) { T* weight_grad_data = weight_grad->mutable_data(context.GetPlace()); // backward for state_weight math::gemm( context.template device_context(), true, false, frame_size, frame_size, batch_size, 1, reset_hidden_prev_data, frame_size, gate_grad_data + frame_size * 2, frame_size * 3, 0, weight_grad_data + frame_size * frame_size * 2, frame_size); // backward for update_gate_weight and reset_gate_weight math::gemm( context.template device_context(), true, false, frame_size, frame_size * 2, batch_size, 1, hidden_prev_data, frame_size, gate_grad_data, frame_size * 3, 0, weight_grad_data, frame_size * 2); } // backward for hidden_prev if (hidden_prev_grad) { T* hidden_prev_grad_data = hidden_prev_grad->mutable_data(context.GetPlace()); auto d_h_p = EigenMatrix::From(*hidden_prev_grad); d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u); math::gemm( context.template device_context(), false, true, batch_size, frame_size, frame_size * 2, 1, gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1, hidden_prev_grad_data, frame_size); } // backward for input if (input_grad) { input_grad->mutable_data(context.GetPlace()); auto d_x = EigenMatrix::From(*input_grad); d_x.device(place) = d_g; } // backward for bias if (bias_grad) { bias_grad->mutable_data(context.GetPlace()); auto d_b = EigenVector::Flatten(*bias_grad); d_b.device(place) = d_g.sum(Eigen::array({{0}})); } } }; } // namespace operators } // namespace paddle