提交 f20fc955 编写于 作者: M minqiyang

Resize output ddims and rows

上级 67308822
......@@ -33,12 +33,14 @@ class ClipByNormKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
auto max_norm = context.Attr<T>("max_norm");
auto in_var = context.InputVar("X");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
Tensor* output = nullptr;
const Tensor* input = nullptr;
if (in_var->IsType<framework::LoDTensor>()) {
input = context.Input<Tensor>("X");
output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
} else if (in_var->IsType<framework::SelectedRows>()) {
auto* x = context.Input<framework::SelectedRows>("X");
......@@ -50,6 +52,11 @@ class ClipByNormKernel : public framework::OpKernel<T> {
merge_func(context.template device_context<DeviceContext>(), *x,
merged_input);
input = &(merged_input->value());
auto* output_selected_rows = context.Output<SelectedRows>("Out");
output_selected_rows->set_rows(merged_input.rows());
output = output_selected_rows->mutable_data();
output->Resize(framework::make_ddim(merged_input.value().dims()));
} else {
PADDLE_THROW("Unexpected branch, input variable type is %s",
in_var->Type().name());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册