提交 55991822 编写于 作者: X xzl

modify GetAttr to Attr

上级 828008e4
...@@ -28,7 +28,7 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -28,7 +28,7 @@ class TransposeOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto in_dim = ctx.Input<Tensor>("X")->dims(); auto in_dim = ctx.Input<Tensor>("X")->dims();
auto axis = ctx.GetAttr<std::vector<int>>("axis"); auto axis = ctx.Attr<std::vector<int>>("axis");
size_t in_dim_size = in_dim.size(); size_t in_dim_size = in_dim.size();
size_t axis_size = axis.size(); size_t axis_size = axis.size();
......
...@@ -98,7 +98,7 @@ class TransposeCUDAKernel : public framework::OpKernel { ...@@ -98,7 +98,7 @@ class TransposeCUDAKernel : public framework::OpKernel {
"It must use GPUPlace."); "It must use GPUPlace.");
auto* in = context.Input<framework::Tensor>("X"); auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
auto axis = context.GetAttr<std::vector<int>>("axis"); auto axis = context.Attr<std::vector<int>>("axis");
TransposeCUDA<T>(context, *in, *out, axis); TransposeCUDA<T>(context, *in, *out, axis);
} }
}; };
...@@ -111,7 +111,7 @@ class TransposeGradCUDAKernel : public framework::OpKernel { ...@@ -111,7 +111,7 @@ class TransposeGradCUDAKernel : public framework::OpKernel {
"It must use GPUPlace."); "It must use GPUPlace.");
auto* in = context.Input<framework::Tensor>(framework::GradVarName("Out")); auto* in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out = context.Output<framework::Tensor>(framework::GradVarName("X")); auto* out = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto axis_temp = context.GetAttr<std::vector<int>>("axis"); auto axis_temp = context.Attr<std::vector<int>>("axis");
std::vector<int> axis(axis_temp); std::vector<int> axis(axis_temp);
......
...@@ -77,7 +77,7 @@ class TransposeKernel : public framework::OpKernel { ...@@ -77,7 +77,7 @@ class TransposeKernel : public framework::OpKernel {
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto axis = context.GetAttr<std::vector<int>>("axis"); auto axis = context.Attr<std::vector<int>>("axis");
int ndims = axis.size(); int ndims = axis.size();
switch (ndims) { switch (ndims) {
case 2: case 2:
...@@ -107,7 +107,7 @@ class TransposeGradKernel : public framework::OpKernel { ...@@ -107,7 +107,7 @@ class TransposeGradKernel : public framework::OpKernel {
auto* out = context.Output<framework::Tensor>(framework::GradVarName("X")); auto* out = context.Output<framework::Tensor>(framework::GradVarName("X"));
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto axis_temp = context.GetAttr<std::vector<int>>("axis"); auto axis_temp = context.Attr<std::vector<int>>("axis");
std::vector<int> axis(axis_temp); std::vector<int> axis(axis_temp);
for (size_t i = 0; i < axis.size(); i++) { for (size_t i = 0; i < axis.size(); i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册