Commit 2a36ad1a authored by Yu Yang

Handle LoD for concat & seq_softmax ops

Parent 211d8186
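The fix is small but easy to miss: gradient kernels allocate fresh output tensors, so unless the kernel explicitly copies the forward input's LoD (the level-of-detail offsets that mark variable-length sequence boundaries) onto the gradient output, that metadata is silently dropped. Below is a minimal, self-contained sketch of the propagation pattern the diff applies; `MiniLoDTensor` and `PropagateLoD` are hypothetical stand-ins for illustration, not Paddle's actual `framework::LoDTensor` API.

```cpp
// A minimal sketch of LoD propagation, assuming a simplified stand-in type;
// Paddle's real framework::LoDTensor exposes similar lod()/set_lod()
// accessors but carries much more state.
#include <cstddef>
#include <iostream>
#include <vector>

// One vector of offsets per LoD level; offsets mark sequence boundaries.
using LoD = std::vector<std::vector<size_t>>;

struct MiniLoDTensor {  // hypothetical stand-in for framework::LoDTensor
  const LoD& lod() const { return lod_; }
  void set_lod(const LoD& lod) { lod_ = lod; }
  LoD lod_;
};

// Copy each forward input's LoD onto the matching gradient output, skipping
// outputs that were pruned (nullptr), mirroring the loop the commit adds.
void PropagateLoD(const std::vector<const MiniLoDTensor*>& x,
                  const std::vector<MiniLoDTensor*>& dx) {
  for (size_t i = 0; i < dx.size(); ++i) {
    if (dx[i] != nullptr) {
      dx[i]->set_lod(x[i]->lod());
    }
  }
}

int main() {
  MiniLoDTensor x0, x1, dx0, dx1;
  x0.set_lod({{0, 2, 5}});  // two sequences of lengths 2 and 3
  x1.set_lod({{0, 1, 4}});  // two sequences of lengths 1 and 3
  PropagateLoD({&x0, &x1}, {&dx0, &dx1});
  // Gradients now carry the same boundaries as their forward inputs.
  std::cout << dx0.lod()[0].back() << " " << dx1.lod()[0].back() << "\n";  // 5 4
  return 0;
}
```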
@@ -62,9 +62,21 @@ class ConcatGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
     auto out_var_names = ctx.Outputs(framework::GradVarName("X"));
-    auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
+    auto outs =
+        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
+    {
+      auto dx = outs;
+      auto x = ins;
+      for (size_t i = 0; i < dx.size(); ++i) {
+        if (dx[i] != nullptr) {
+          dx[i]->set_lod(x[i]->lod());
+        }
+      }
+    }
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     // get the output tensors whose names are not kEmptyVarName
@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& input,
-                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const std::vector<const framework::LoDTensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     size_t num = outputs->size();
@@ -189,7 +189,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
-                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const std::vector<const framework::LoDTensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     int o_num = outputs->size();
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 #include <vector>
 #include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/lod_tensor.h"

 namespace paddle {
 namespace operators {
@@ -57,7 +57,7 @@ template <typename DeviceContext, typename T>
 class ConcatGradFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
-                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const std::vector<const framework::LoDTensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs);
 };
@@ -66,6 +66,9 @@ class SequenceSoftmaxGradKernel : public framework::OpKernel<T> {
     auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* x = ctx.Input<LoDTensor>("X");
     auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+    if (x_grad) {
+      x_grad->set_lod(x->lod());
+    }
     auto lod = x->lod();
     const size_t level = lod.size() - 1;
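The seq_softmax change follows the same rule on the gradient path: the softmax is normalized per sequence, and the sequence boundaries live only in the LoD, which the kernel reads at its finest level (`level = lod.size() - 1`). A hedged sketch of that boundary walk, reusing the illustrative `LoD` alias from above; `ForEachSequence` is a made-up helper for this note, not part of Paddle.

```cpp
// Illustrative only: walk the per-sequence spans implied by the finest LoD
// level, the same indexing the kernel uses via level = lod.size() - 1.
#include <cstddef>
#include <functional>
#include <vector>

using LoD = std::vector<std::vector<size_t>>;

// Visit each sequence span [begin, end) at the finest LoD level; the
// gradient, like the forward softmax, is computed independently per span.
void ForEachSequence(const LoD& lod,
                     const std::function<void(size_t, size_t)>& visit) {
  const size_t level = lod.size() - 1;
  for (size_t i = 0; i + 1 < lod[level].size(); ++i) {
    visit(lod[level][i], lod[level][i + 1]);
  }
}

int main() {
  size_t total = 0;
  ForEachSequence({{0, 2, 5}}, [&](size_t begin, size_t end) {
    total += end - begin;  // visits [0, 2) then [2, 5)
  });
  return total == 5 ? 0 : 1;
}
```

Setting the LoD on `x_grad` up front, before any numeric work, presumably keeps downstream sequence-aware ops seeing correct boundaries even on paths where the gradient buffer is otherwise left untouched.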