Unverified commit d8fe517b, authored by H Huihuang Zheng, committed by GitHub

Add Support for SelectedRows for Transpose OP and Fix a Bug That SelectedRows Cannot be Supported in SimNet (#25536)

This PR fixes a bug that prevented SelectedRows from being supported in SimNet. The cause is that the dygraph basic_engine did not copy a var's type when that var needed to be accumulated during backward. So when a SelectedRows var had to be accumulated, as in SimNet where the net is called twice, the var's type fell back to the default LoDTensor and the bug appeared. The fix is simply to copy the type as well.
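In essence, the accumulation path builds a fresh VariableWrapper whose type defaults to LoDTensor, so the original type has to be carried over explicitly. Below is a simplified sketch of the idea with explanatory comments; the names follow the first diff hunk and the surrounding engine code is omitted.

```cpp
// Simplified sketch of the fixed accumulation path in the dygraph basic
// engine (see the first hunk below for the actual change).
// A newly constructed VariableWrapper defaults to LOD_TENSOR, so when the
// source var holds SelectedRows its type must be copied explicitly;
// otherwise the accumulated gradient silently becomes a dense LoDTensor.
auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
tmp_var->SetType(var->Type());  // preserve SELECTED_ROWS (or any other type)
var = tmp_var;
need_accu_var_list_.emplace_back(iter->second.get(), var);
```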

Without this PR, accumulated SelectedRows variables in dygraph were changed into LoDTensor. Once that bug was fixed so that SimNet really uses SelectedRows, `test_imperative_lod_tensor_to_selected_rows` failed with an error saying that SelectedRows was not supported by the Transpose OP. To fix that as well, this PR also adds SelectedRows support to the Transpose OP.
Parent 0f8dc611
......@@ -205,7 +205,9 @@ void BasicEngine::Execute() {
           continue;
         }
-        var = std::make_shared<VariableWrapper>(var->Name());
+        auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
+        tmp_var->SetType(var->Type());
+        var = tmp_var;
         need_accu_var_list_.emplace_back(iter->second.get(), var);
       }
......
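The Transpose kernel changes below rely on two framework helpers, `GetLoDTensorOrSelectedRowsValueFromVar` and `GetMutableLoDTensorOrSelectedRowsValueFromVar`, to obtain the underlying dense tensor whether the variable holds a LoDTensor or a SelectedRows. Their implementation is not part of this diff; the sketch below only illustrates the semantics assumed from their names and from how the kernels use them, not Paddle's actual code.

```cpp
// Minimal sketch (assumed semantics, not the real Paddle implementation).
// Assumes the paddle/fluid/framework headers for Variable, LoDTensor and
// SelectedRows are available.
const framework::Tensor* GetLoDTensorOrSelectedRowsValueFromVar(
    const framework::Variable& var) {
  if (var.IsType<framework::LoDTensor>()) {
    return &var.Get<framework::LoDTensor>();
  }
  // A SelectedRows stores its data in an embedded dense value tensor.
  return &var.Get<framework::SelectedRows>().value();
}

framework::Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(
    framework::Variable* var) {
  if (var->IsType<framework::LoDTensor>()) {
    return var->GetMutable<framework::LoDTensor>();
  }
  return var->GetMutable<framework::SelectedRows>()->mutable_value();
}
```

With such helpers, a kernel can read and write through plain `framework::Tensor*` pointers and thus transparently handle both variable types, which is exactly the pattern applied to the four Transpose kernels below.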
......@@ -660,19 +660,26 @@ template <typename DeviceContext, typename T>
 class TransposeGPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<framework::Tensor>("X");
-    auto* out = context.Output<framework::Tensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-    if (out->numel() == 0) {
+    auto* x = context.InputVar("X");
+    auto* out = context.OutputVar("Out");
+    const framework::Tensor* x_tensor =
+        GetLoDTensorOrSelectedRowsValueFromVar(*x);
+    framework::Tensor* out_tensor =
+        GetMutableLoDTensorOrSelectedRowsValueFromVar(out);
+    out_tensor->mutable_data<T>(context.GetPlace());
+    if (out_tensor->numel() == 0) {
       return;
     }
     std::vector<int> axis = context.Attr<std::vector<int>>("axis");
     int ndims = axis.size();
     const auto& dev_ctx = context.template device_context<DeviceContext>();
-    auto ret = TransposeSimple<T>::run(dev_ctx, *x, axis, out);
+    auto ret = TransposeSimple<T>::run(dev_ctx, *x_tensor, axis, out_tensor);
     if (!ret) {
-      TransCompute<DeviceContext, T>(ndims, dev_ctx, *x, out, axis);
+      TransCompute<DeviceContext, T>(ndims, dev_ctx, *x_tensor, out_tensor,
+                                     axis);
     }
   }
 };
......@@ -680,14 +687,19 @@ template <typename DeviceContext, typename T>
 class TransposeGradGPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* out_grad =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* x_grad =
-        context.Output<framework::Tensor>(framework::GradVarName("X"));
-    if (!x_grad) return;
-    x_grad->mutable_data<T>(context.GetPlace());
-    if (x_grad->numel() == 0) {
+    auto* out_grad = context.InputVar(framework::GradVarName("Out"));
+    auto* x_grad = context.OutputVar(framework::GradVarName("X"));
+    if (!x_grad) {
+      return;
+    }
+    const framework::Tensor* out_grad_tensor =
+        GetLoDTensorOrSelectedRowsValueFromVar(*out_grad);
+    framework::Tensor* x_grad_tensor =
+        GetMutableLoDTensorOrSelectedRowsValueFromVar(x_grad);
+    x_grad_tensor->mutable_data<T>(context.GetPlace());
+    if (x_grad_tensor->numel() == 0) {
       return;
     }
     std::vector<int> axis = context.Attr<std::vector<int>>("axis");
......@@ -699,11 +711,11 @@ class TransposeGradGPUKernel : public framework::OpKernel<T> {
     int ndims = axis.size();
     const auto& dev_ctx = context.template device_context<DeviceContext>();
-    auto ret =
-        TransposeSimple<T>::run(dev_ctx, *out_grad, reversed_axis, x_grad);
+    auto ret = TransposeSimple<T>::run(dev_ctx, *out_grad_tensor, reversed_axis,
+                                       x_grad_tensor);
     if (!ret) {
-      TransCompute<DeviceContext, T>(ndims, dev_ctx, *out_grad, x_grad,
-                                     reversed_axis);
+      TransCompute<DeviceContext, T>(ndims, dev_ctx, *out_grad_tensor,
+                                     x_grad_tensor, reversed_axis);
     }
   }
 };
......
......@@ -64,16 +64,23 @@ template <typename DeviceContext, typename T>
 class TransposeKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<framework::Tensor>("X");
-    auto* out = context.Output<framework::Tensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-    if (out->numel() == 0) {
+    auto* x = context.InputVar("X");
+    auto* out = context.OutputVar("Out");
+    const framework::Tensor* x_tensor =
+        GetLoDTensorOrSelectedRowsValueFromVar(*x);
+    framework::Tensor* out_tensor =
+        GetMutableLoDTensorOrSelectedRowsValueFromVar(out);
+    out_tensor->mutable_data<T>(context.GetPlace());
+    if (out_tensor->numel() == 0) {
       return;
     }
     std::vector<int> axis = context.Attr<std::vector<int>>("axis");
     int ndims = axis.size();
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    TransCompute<DeviceContext, T>(ndims, dev_ctx, *x, out, axis);
+    TransCompute<DeviceContext, T>(ndims, dev_ctx, *x_tensor, out_tensor, axis);
   }
 };
......@@ -81,14 +88,19 @@ template <typename DeviceContext, typename T>
 class TransposeGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* out_grad =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* x_grad =
-        context.Output<framework::Tensor>(framework::GradVarName("X"));
-    if (!x_grad) return;
-    x_grad->mutable_data<T>(context.GetPlace());
-    if (x_grad->numel() == 0) {
+    auto* out_grad = context.InputVar(framework::GradVarName("Out"));
+    auto* x_grad = context.OutputVar(framework::GradVarName("X"));
+    if (!x_grad) {
+      return;
+    }
+    const framework::Tensor* out_grad_tensor =
+        GetLoDTensorOrSelectedRowsValueFromVar(*out_grad);
+    framework::Tensor* x_grad_tensor =
+        GetMutableLoDTensorOrSelectedRowsValueFromVar(x_grad);
+    x_grad_tensor->mutable_data<T>(context.GetPlace());
+    if (x_grad_tensor->numel() == 0) {
       return;
     }
......@@ -101,8 +113,8 @@ class TransposeGradKernel : public framework::OpKernel<T> {
     int ndims = axis.size();
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    TransCompute<DeviceContext, T>(ndims, dev_ctx, *out_grad, x_grad,
-                                   reversed_axis);
+    TransCompute<DeviceContext, T>(ndims, dev_ctx, *out_grad_tensor,
+                                   x_grad_tensor, reversed_axis);
   }
 };
......
......@@ -42,7 +42,7 @@ class EmbeddingLayer(object):
         # causes crush in dy2stat. Set it to True after fixing it.
         emb = Embedding(
             size=[self.dict_size, self.emb_dim],
-            is_sparse=False,
+            is_sparse=True,
             padding_idx=self.padding_idx,
             param_attr=attr.ParamAttr(
                 name=self.name, initializer=fluid.initializer.Xavier()))
......
......@@ -149,7 +149,6 @@ def train(conf_dict, to_static):
                 pred = pos_score
                 _, neg_score = net(left, neg_right)
                 avg_cost = loss.compute(pos_score, neg_score)
-                #avg_cost = loss.compute(pos_score, pos_score)
                 losses.append(np.mean(avg_cost.numpy()))
                 avg_cost.backward()
                 optimizer.minimize(avg_cost)
......
......@@ -186,7 +186,8 @@ class TestDygraphSimpleNet(unittest.TestCase):
                                 k - 1]] = out[k]
                 self.assertTrue(
-                    np.array_equal(static_loss_value, dy_loss_value))
+                    np.allclose(
+                        static_loss_value, dy_loss_value, rtol=1e-3))
                 for key, value in six.iteritems(static_param_init):
                     self.assertTrue(np.array_equal(value, dy_param_init[key]))
                 for key, value in six.iteritems(static_param_updated):
......