提交 74f460fd 编写于 作者: D dangqingqing

Fix specialization of template member functions in a non-template class for GCC 5.0.

上级 f6b518c9
......@@ -209,8 +209,7 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
// Explicit specialization: look up the output variable registered under
// `name` and return its underlying Tensor, or nullptr when the (optional)
// output is absent.
// GetTensorFromVar returns a const Tensor*; outputs are writable by the
// kernel, so constness is cast away here. const_cast is safe because the
// Variable object itself is not const.
// NOTE: a merged pre-change copy of this body (early-return form) was
// removed — it was unreachable dead code duplicating the return below.
template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
  auto* var = OutputVar(name);
  return var == nullptr ? nullptr : const_cast<Tensor*>(GetTensorFromVar(var));
}
template <>
......@@ -222,7 +221,9 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope().FindVar(sub_name);
return var == nullptr ? nullptr : GetTensorFromVar(var);
return var == nullptr
? nullptr
: const_cast<Tensor*>(GetTensorFromVar(var));
});
return res;
}
......
......@@ -327,13 +327,13 @@ class InferShapeContext {
return res;
}
// Extract the Tensor payload held by `var`.
// A variable may hold either a LoDTensor (returned as a Tensor — presumably
// LoDTensor derives from Tensor, as the implicit pointer conversion below
// requires) or a plain Tensor; any other payload trips the enforce.
// Returns const Tensor*; callers that need a mutable tensor (outputs) must
// const_cast at the call site.
// NOTE: a merged pre-change copy of this method (non-const signature with
// const_cast'ed returns) was removed — it duplicated this definition.
const Tensor* GetTensorFromVar(const Variable* var) const {
  if (var->IsType<LoDTensor>()) {
    return &var->Get<LoDTensor>();
  }
  // NOTE(review): the format string has a %s placeholder but no argument is
  // supplied, so the failure message cannot name the input — confirm
  // PADDLE_ENFORCE tolerates a missing vararg, or pass the name through.
  PADDLE_ENFORCE(var->IsType<Tensor>(),
                 "The Input(%s) must be LoDTensor or Tensor.");
  return &var->Get<Tensor>();
}
private:
......@@ -341,6 +341,13 @@ class InferShapeContext {
const Scope& scope_;
};
// Forward declarations of the explicit member-template specializations
// defined in the .cc file. Per the commit title, declaring them here fixes
// specialization of template member functions of a non-template class under
// GCC 5.0 (the specializations must be visible before use in other TUs).
template <>
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const;
template <>
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
const std::string& name) const;
// Primary template intentionally left undefined; only its explicit
// specializations (declared elsewhere) provide a definition.
template <typename T>
struct EigenDeviceConverter;
......@@ -397,6 +404,13 @@ class ExecutionContext : public InferShapeContext {
const platform::DeviceContext* device_context_;
};
// Forward declarations of the explicit member-template specializations
// defined in the .cc file — same GCC 5.0 fix as the InferShapeContext
// declarations above: member-template specializations of a non-template
// class must be declared before use.
template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const;
class OpKernel {
public:
/**
......
......@@ -66,7 +66,7 @@ class SequenceAvgPoolGradOp : public framework::OperatorWithKernel {
auto x_dims = ctx.Input<framework::LoDTensor>("X")->dims();
PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(),
"The rank of output grad must equal to Input(X).");
for (size_t i = 1; i < og_dims.size(); ++i) {
for (int64_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch.");
}
auto* x_grad =
......
......@@ -38,7 +38,7 @@ class SequenceAvgPoolKernel : public framework::OpKernel {
out->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < lod[0].size() - 1; ++i) {
for (int i = 0; i < static_cast<int>(lod[0].size()) - 1; ++i) {
Tensor in_t = in->Slice<T>(static_cast<int>(lod[0][i]),
static_cast<int>(lod[0][i + 1]));
Tensor out_t = out->Slice<T>(i, i + 1);
......@@ -64,7 +64,7 @@ class SequenceAvgPoolGradKernel : public framework::OpKernel {
in_g->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < lod[0].size() - 1; ++i) {
for (int i = 0; i < static_cast<int>(lod[0].size()) - 1; ++i) {
auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[0][i]),
static_cast<int>(lod[0][i + 1]));
auto out_g_t = out_g->Slice<T>(i, i + 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册