提交 74f460fd 编写于 作者: D dangqingqing

Fix specialization of template member functions in the non-template class in GCC 5.0.

上级 f6b518c9
...@@ -209,8 +209,7 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>( ...@@ -209,8 +209,7 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
template <> template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const { Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
auto* var = OutputVar(name); auto* var = OutputVar(name);
if (var == nullptr) return nullptr; return var == nullptr ? nullptr : const_cast<Tensor*>(GetTensorFromVar(var));
return GetTensorFromVar(var);
} }
template <> template <>
...@@ -222,7 +221,9 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>( ...@@ -222,7 +221,9 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) {
auto var = scope().FindVar(sub_name); auto var = scope().FindVar(sub_name);
return var == nullptr ? nullptr : GetTensorFromVar(var); return var == nullptr
? nullptr
: const_cast<Tensor*>(GetTensorFromVar(var));
}); });
return res; return res;
} }
......
...@@ -327,13 +327,13 @@ class InferShapeContext { ...@@ -327,13 +327,13 @@ class InferShapeContext {
return res; return res;
} }
Tensor* GetTensorFromVar(const Variable* var) const { const Tensor* GetTensorFromVar(const Variable* var) const {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return const_cast<LoDTensor*>(&var->Get<LoDTensor>()); return &var->Get<LoDTensor>();
} }
PADDLE_ENFORCE(var->IsType<Tensor>(), PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input(%s) must be LoDTensor or Tensor."); "The Input(%s) must be LoDTensor or Tensor.");
return const_cast<Tensor*>(&var->Get<Tensor>()); return &var->Get<Tensor>();
} }
private: private:
...@@ -341,6 +341,13 @@ class InferShapeContext { ...@@ -341,6 +341,13 @@ class InferShapeContext {
const Scope& scope_; const Scope& scope_;
}; };
template <>
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const;
template <>
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
const std::string& name) const;
template <typename T> template <typename T>
struct EigenDeviceConverter; struct EigenDeviceConverter;
...@@ -397,6 +404,13 @@ class ExecutionContext : public InferShapeContext { ...@@ -397,6 +404,13 @@ class ExecutionContext : public InferShapeContext {
const platform::DeviceContext* device_context_; const platform::DeviceContext* device_context_;
}; };
template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const;
class OpKernel { class OpKernel {
public: public:
/** /**
......
...@@ -66,7 +66,7 @@ class SequenceAvgPoolGradOp : public framework::OperatorWithKernel { ...@@ -66,7 +66,7 @@ class SequenceAvgPoolGradOp : public framework::OperatorWithKernel {
auto x_dims = ctx.Input<framework::LoDTensor>("X")->dims(); auto x_dims = ctx.Input<framework::LoDTensor>("X")->dims();
PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(), PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(),
"The rank of output grad must equal to Input(X)."); "The rank of output grad must equal to Input(X).");
for (size_t i = 1; i < og_dims.size(); ++i) { for (int64_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch."); PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch.");
} }
auto* x_grad = auto* x_grad =
......
...@@ -38,7 +38,7 @@ class SequenceAvgPoolKernel : public framework::OpKernel { ...@@ -38,7 +38,7 @@ class SequenceAvgPoolKernel : public framework::OpKernel {
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < lod[0].size() - 1; ++i) { for (int i = 0; i < static_cast<int>(lod[0].size()) - 1; ++i) {
Tensor in_t = in->Slice<T>(static_cast<int>(lod[0][i]), Tensor in_t = in->Slice<T>(static_cast<int>(lod[0][i]),
static_cast<int>(lod[0][i + 1])); static_cast<int>(lod[0][i + 1]));
Tensor out_t = out->Slice<T>(i, i + 1); Tensor out_t = out->Slice<T>(i, i + 1);
...@@ -64,7 +64,7 @@ class SequenceAvgPoolGradKernel : public framework::OpKernel { ...@@ -64,7 +64,7 @@ class SequenceAvgPoolGradKernel : public framework::OpKernel {
in_g->mutable_data<T>(context.GetPlace()); in_g->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < lod[0].size() - 1; ++i) { for (int i = 0; i < static_cast<int>(lod[0].size()) - 1; ++i) {
auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[0][i]), auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[0][i]),
static_cast<int>(lod[0][i + 1])); static_cast<int>(lod[0][i + 1]));
auto out_g_t = out_g->Slice<T>(i, i + 1); auto out_g_t = out_g->Slice<T>(i, i + 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册