未验证 提交 8696335f 编写于 作者: L Leo Chen 提交者: GitHub

Fix dtype of ungenerated grad var (#28511)

* fix dtype of ungenerated grad var

* update ut

* refine code

* set default dtype

* fix could_use_cudnn bug

* remove debug code

* re-implement

* fix bug
上级 03e07273
......@@ -99,9 +99,15 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) {
}
if (tensor && !tensor->IsInitialized()) {
VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero";
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(op.place());
tensor->mutable_data(op.place(), var->DataType());
// NOTE(zhiqiu): since grad variable is ungenerated, so the dtype is not
// correct. var->DataType() returns the default dtype, which is float32.
// Here, we use the type of the corresponding forward datatype.
tensor->mutable_data(op.place(), var->ForwardDataType());
VLOG(6) << "Set ungenerated Grad: " << var->Name()
<< " as zero with dtype "
<< framework::DataTypeToString(var->ForwardDataType());
operators::math::set_constant(*dev_ctx, tensor, 0.0);
}
}
......
......@@ -384,6 +384,16 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
}
VLOG(4) << LayerDebugString(op.Type(), ins, outs);
// set the output var
for (auto& var_pair : outs) {
for (auto& var : var_pair.second) {
// NOTE(zhiqu): The ouput may be NULL because of pruning.
if (var) {
SetForwardDataTypeOfGradVar(var);
}
}
}
}
void OpBase::Run(const framework::OperatorBase& op,
......
......@@ -50,7 +50,7 @@ void SetForwardDataTypeOfGradVar<VariableWrapper>(
const std::shared_ptr<VariableWrapper>& var) {
if (var->HasGradVar()) {
auto grad_var = var->GetGradVar();
VLOG(6) << "Set grad var (" << grad_var->Name() << ") dtype to ("
VLOG(6) << "Set grad var (" << grad_var->Name() << ")'s forward dtype to ("
<< framework::DataTypeToString(var->DataType()) << ").";
grad_var->SetForwardDataType(var->DataType());
}
......
......@@ -241,9 +241,10 @@ class VariableWrapper {
void SetGradVar(const std::shared_ptr<VariableWrapper>& var) {
auto shared_var = grad_var_.lock();
if (shared_var != var) {
PADDLE_ENFORCE_EQ(shared_var, nullptr,
platform::errors::PermissionDenied(
"Cannot set gradient var wrapper twice"));
PADDLE_ENFORCE_EQ(
shared_var, nullptr,
platform::errors::PermissionDenied(
"Cannot set gradient variable wrapper twice for %s", name_));
grad_var_ = var;
}
}
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -552,8 +553,12 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> {
}
auto *out_data = out->data<T>();
auto *out_grad_data = out_grad->data<T>();
// maybe need check exist
auto *in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
// need check exist
T *in_grad_data = nullptr;
if (in_grad) {
in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
}
bool has_seq_length = ctx.HasInput("SequenceLength");
std::vector<int> SequenceLength;
......@@ -583,40 +588,52 @@ class RNNGradCudnnKernel : public framework::OpKernel<T> {
const uint8_t *reserve_data = reserve->data<uint8_t>();
if (!has_seq_length) {
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
rnn.init_c_desc(), init_c_grad_data, workspace_data_.data<uint8_t>(),
workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
rnn.init_h_desc(), init_h_data, rnn.y_descs(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
weight_grad_data, const_cast<uint8_t *>(reserve_data), reserve_size));
if (in_grad) {
// This interface is used when the input/output is unpadded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
rnn.init_c_desc(), init_c_grad_data,
workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
}
if (!weight_grad_list.empty()) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
rnn.init_h_desc(), init_h_data, rnn.y_descs(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
weight_grad_data, const_cast<uint8_t *>(reserve_data),
reserve_size));
}
} else {
#if CUDNN_VERSION >= 7201
// for train
// This interface is used when the input/output is padded.
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(),
out_grad_data, nullptr, nullptr, rnn.last_h_desc(), last_h_grad_data,
rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
rnn.x_seq_desc(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
rnn.init_h_desc(), init_h_data, rnn.y_seq_desc(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
weight_grad_data, const_cast<uint8_t *>(reserve_data), reserve_size));
if (in_grad) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data,
rnn.y_seq_desc(), out_grad_data, nullptr, nullptr,
rnn.last_h_desc(), last_h_grad_data, rnn.last_c_desc(),
last_c_grad_data, rnn.weight_desc(), weight_data, rnn.init_h_desc(),
init_h_data, rnn.init_c_desc(), init_c_data, rnn.x_seq_desc(),
in_grad_data, rnn.init_h_desc(), init_h_grad_data,
rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr,
workspace_data_.data<uint8_t>(), workspace_size,
const_cast<uint8_t *>(reserve_data), reserve_size));
}
if (!weight_grad_list.empty()) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnRNNBackwardWeightsEx(
handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
rnn.init_h_desc(), init_h_data, rnn.y_seq_desc(),
out->data<T>(), workspace_data_.data<uint8_t>(), workspace_size,
rnn.weight_desc(), weight_grad_data,
const_cast<uint8_t *>(reserve_data), reserve_size));
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
......
......@@ -58,6 +58,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}},
{"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}},
{"momentum", {"Param", "Grad", "Velocity", "LearningRate"}},
{"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}},
};
// NOTE(zhiqiu): Like op_ins_map.
......@@ -87,6 +88,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"multiclass_nms3", {"Out", "NmsRoisNum"}},
{"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
{"momentum", {"ParamOut", "VelocityOut"}},
{"rnn", {"DropoutState", "Reserve", "Out", "State"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
......@@ -134,6 +136,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"update_loss_scaling",
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
{"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
{"rnn", {"DropoutState"}},
};
// clang-format off
......
......@@ -272,6 +272,7 @@ class TestLSTM(unittest.TestCase):
def test_predict(self):
predict_test_util(self.place, "LSTM")
predict_test_util(self.place, "LSTM", False)
def runTest(self):
self.test_with_initial_state()
......@@ -280,7 +281,7 @@ class TestLSTM(unittest.TestCase):
self.test_predict()
def predict_test_util(place, mode):
def predict_test_util(place, mode, stop_gradient=True):
place = paddle.set_device(place)
paddle.seed(123)
np.random.seed(123)
......@@ -298,7 +299,7 @@ def predict_test_util(place, mode):
return self.rnn(input)
x = paddle.randn((4, 10, 16))
x.stop_gradient = False
x.stop_gradient = stop_gradient
seq_len = paddle.to_tensor(np.array([10, 6, 8, 5]))
mask = sequence_mask(seq_len, maxlen=10, dtype=x.dtype)
mask = paddle.unsqueeze(mask, [2])
......
......@@ -989,39 +989,50 @@ class RNNBase(LayerList):
def _cudnn_impl(self, inputs, initial_states, sequence_length):
if not self.time_major:
inputs = paddle.tensor.transpose(inputs, [1, 0, 2])
out = self._helper.create_variable_for_type_inference(inputs.dtype)
state = [
self._helper.create_variable_for_type_inference(inputs.dtype)
for i in range(self.state_components)
]
reserve = self._helper.create_variable_for_type_inference(
dtype=fluid.core.VarDesc.VarType.UINT8, stop_gradient=True)
inputs = {
'Input': inputs,
'WeightList': self._all_weights,
'PreState': initial_states,
'SequenceLength': sequence_length
}
attrs = {
'dropout_prob': self.dropout,
'is_bidirec': self.num_directions == 2,
'input_size': self.input_size,
'hidden_size': self.hidden_size,
'num_layers': self.num_layers,
'mode': self.mode,
'is_test': not self.training
}
outputs = {
'Out': out,
'State': state,
'Reserve': reserve,
'DropoutState': self._dropout_state,
}
if fluid.framework.in_dygraph_mode():
_, _, out, state = framework.core.ops.rnn(
inputs, initial_states, self._all_weights, sequence_length,
self._dropout_state, self.state_components, 'dropout_prob',
self.dropout, 'is_bidirec', self.num_directions == 2,
'input_size', self.input_size, 'hidden_size', self.hidden_size,
'num_layers', self.num_layers, 'mode', self.mode, 'is_test',
not self.training)
else:
out = self._helper.create_variable_for_type_inference(inputs.dtype)
state = [
self._helper.create_variable_for_type_inference(inputs.dtype)
for i in range(self.state_components)
]
reserve = self._helper.create_variable_for_type_inference(
dtype=fluid.core.VarDesc.VarType.UINT8, stop_gradient=True)
inputs = {
'Input': inputs,
'WeightList': self._all_weights,
'PreState': initial_states,
'SequenceLength': sequence_length
}
attrs = {
'dropout_prob': self.dropout,
'is_bidirec': self.num_directions == 2,
'input_size': self.input_size,
'hidden_size': self.hidden_size,
'num_layers': self.num_layers,
'mode': self.mode,
'is_test': not self.training
}
outputs = {
'Out': out,
'State': state,
'Reserve': reserve,
'DropoutState': self._dropout_state,
}
self._helper.append_op(
type="rnn", inputs=inputs, outputs=outputs, attrs=attrs)
self._helper.append_op(
type="rnn", inputs=inputs, outputs=outputs, attrs=attrs)
out = paddle.tensor.transpose(out,
[1, 0, 2]) if not self.time_major else out
return out, tuple(state) if len(state) > 1 else state[0]
......@@ -1032,15 +1043,15 @@ class RNNBase(LayerList):
if initial_states is None:
state_shape = (self.num_layers * self.num_directions, -1,
self.hidden_size)
if self.state_components == 1:
initial_states = paddle.fluid.layers.fill_constant_batch_size_like(
initial_states = tuple([
paddle.fluid.layers.fill_constant_batch_size_like(
inputs, state_shape, dtype, 0, batch_index, 1)
else:
initial_states = tuple([
paddle.fluid.layers.fill_constant_batch_size_like(
inputs, state_shape, dtype, 0, batch_index, 1)
for _ in range(self.state_components)
])
for _ in range(self.state_components)
])
else:
initial_states = [initial_states] if isinstance(
initial_states,
paddle.fluid.framework.Variable) else initial_states
if self.could_use_cudnn:
# Add CPU kernel and dispatch in backend later
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册