Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
04a35150
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
04a35150
编写于
11月 09, 2017
作者:
Y
yangyaming
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove MulValu* and reduce time cost for unit test.
上级
07f3f07f
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
3 addition
and
531 deletion
+3
-531
paddle/function/MulValueOp.cpp
paddle/function/MulValueOp.cpp
+0
-155
paddle/function/MulValueOp.h
paddle/function/MulValueOp.h
+0
-55
paddle/function/MulValueOpGpu.cu
paddle/function/MulValueOpGpu.cu
+0
-116
paddle/function/MulValueOpTest.cpp
paddle/function/MulValueOpTest.cpp
+0
-75
paddle/function/ScaleSubRegionOpTest.cpp
paddle/function/ScaleSubRegionOpTest.cpp
+3
-3
paddle/gserver/layers/MulValueLayer.cpp
paddle/gserver/layers/MulValueLayer.cpp
+0
-75
paddle/gserver/layers/MulValueLayer.h
paddle/gserver/layers/MulValueLayer.h
+0
-52
未找到文件。
paddle/function/MulValueOp.cpp
已删除
100644 → 0
浏览文件 @
07f3f07f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MulValueOp.h"
#include "paddle/function/TensorShape.h"
namespace
paddle
{
template
<
>
void
MulValue
<
DEVICE_TYPE_CPU
>
(
real
*
outputs
,
const
real
*
inputs
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
)
{
real
value
=
conf
.
get
<
real
>
(
"value"
);
int
number
=
shape
[
0
];
int
channel
=
shape
[
1
];
int
height
=
shape
[
2
];
int
width
=
shape
[
3
];
memcpy
(
outputs
,
inputs
,
number
*
channel
*
height
*
width
*
sizeof
(
real
));
for
(
int
n
=
0
;
n
<
number
;
++
n
)
{
// indices start from 1
int
offset
=
n
*
6
;
for
(
int
c
=
indices
[
offset
]
-
1
;
c
<
indices
[
offset
+
1
];
++
c
)
{
for
(
int
h
=
indices
[
offset
+
2
]
-
1
;
h
<
indices
[
offset
+
3
];
++
h
)
{
for
(
int
w
=
indices
[
offset
+
4
]
-
1
;
w
<
indices
[
offset
+
5
];
++
w
)
{
int
idx
=
((
n
*
channel
+
c
)
*
height
+
h
)
*
width
+
w
;
outputs
[
idx
]
*=
value
;
}
}
}
}
}
template
<
>
void
MulValueGrad
<
DEVICE_TYPE_CPU
>
(
const
real
*
inGrad
,
real
*
outGrad
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
)
{
real
value
=
conf
.
get
<
real
>
(
"value"
);
int
number
=
shape
[
0
];
int
channel
=
shape
[
1
];
int
height
=
shape
[
2
];
int
width
=
shape
[
3
];
for
(
int
n
=
0
;
n
<
number
;
++
n
)
{
for
(
int
c
=
0
;
c
<
channel
;
++
c
)
{
for
(
int
h
=
0
;
h
<
height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
width
;
++
w
)
{
int
idx
=
((
n
*
channel
+
c
)
*
height
+
h
)
*
width
+
w
;
int
offset
=
n
*
6
;
if
(
c
>=
(
indices
[
offset
]
-
1
)
&&
c
<=
(
indices
[
offset
+
1
]
-
1
)
&&
h
>=
(
indices
[
offset
+
2
]
-
1
)
&&
h
<=
(
indices
[
offset
+
3
]
-
1
)
&&
w
>=
(
indices
[
offset
+
4
]
-
1
)
&&
w
<=
(
indices
[
offset
+
5
]
-
1
))
{
outGrad
[
idx
]
+=
inGrad
[
idx
]
*
value
;
}
else
{
outGrad
[
idx
]
+=
inGrad
[
idx
];
}
}
}
}
}
}
/**
* \brief For each instance, MulValue can be used to multiply a value to a
* specified sub continuous region. By providing start index and end
* index for C/H/W, you can specify the location and shape of the region.
*
* Argument in this Function:
* \param inputs A 4-D tensor with shape [N, C, H, W], only one input.
* \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
* \param outputs A 4-D tensor with same shape as inputs, output value.
*/
template
<
DeviceType
Device
>
class
MulValueFunc
:
public
FunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
conf_
=
config
;
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
2UL
,
inputs
.
size
());
CHECK_EQ
(
1UL
,
outputs
.
size
());
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ASSIGN_TO
);
TensorShape
shape
=
inputs
[
0
].
shape
();
MulValue
<
Device
>
(
outputs
[
0
].
data
<
real
>
(),
inputs
[
0
].
data
<
real
>
(),
inputs
[
1
].
data
<
real
>
(),
shape
,
conf_
);
}
private:
FuncConfig
conf_
;
};
/**
* \brief The backward propagation of MulValue Function.
*
* Argument in this Function:
* \param inputs A 4-D tensor with shape [N, C, H, W], output gradient.
* \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
* \param outputs A 4-D tensor with shape [N, C, H, W], gradient of input value.
*/
template
<
DeviceType
Device
>
class
MulValueGradFunc
:
public
FunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
conf_
=
config
;
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
2UL
,
inputs
.
size
());
CHECK_EQ
(
1UL
,
outputs
.
size
());
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ADD_TO
);
TensorShape
shape
=
inputs
[
0
].
shape
();
MulValueGrad
<
Device
>
(
inputs
[
0
].
data
<
real
>
(),
outputs
[
0
].
data
<
real
>
(),
inputs
[
1
].
data
<
real
>
(),
shape
,
conf_
);
}
private:
FuncConfig
conf_
;
};
REGISTER_TYPED_FUNC
(
MulValue
,
CPU
,
MulValueFunc
);
REGISTER_TYPED_FUNC
(
MulValueGrad
,
CPU
,
MulValueGradFunc
);
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC
(
MulValue
,
GPU
,
MulValueFunc
);
REGISTER_TYPED_FUNC
(
MulValueGrad
,
GPU
,
MulValueGradFunc
);
#endif
}
// namespace paddle
paddle/function/MulValueOp.h
已删除
100644 → 0
浏览文件 @
07f3f07f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace
paddle
{
/**
* \brief Function to multiply a value to values in specified sub continuous
* region. Indices must be provided to indcate the location and shape of
* the region and the multiplied value is passed by configure variable.
*
*
* \param[out] outputs Output value.
* \param[in] inputs Input data which contains NCHW information.
* \param[in] indices Indices data to indcate the sub region.
* \param[in] shape Tensor shape of input value.
* \param[in] conf Configure variable which contains the multiplied value.
*/
template
<
DeviceType
Device
>
void
MulValue
(
real
*
outputs
,
const
real
*
inputs
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
);
/**
* \brief Back propagation function of MulValue.
*
* \param[out] inGrad Gradients of previous layer.
* \param[in] outGrad Output gradient.
* \param[in] indices Indices data.
* \param[in] shape The Shape of input tensor.
* \param[in] conf Configure variable.
*/
template
<
DeviceType
Device
>
void
MulValueGrad
(
const
real
*
inGrad
,
real
*
outGrad
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
);
}
// namespace paddle
paddle/function/MulValueOpGpu.cu
已删除
100644 → 0
浏览文件 @
07f3f07f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MulValueOp.h"
#include "hl_base.h"
namespace
paddle
{
__global__
void
KeMulValue
(
real
*
outputs
,
const
real
*
inputs
,
const
real
*
indices
,
real
value
,
int
channel
,
int
height
,
int
width
,
int
nthreads
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
nthreads
)
{
const
int
w
=
idx
%
width
;
const
int
h
=
(
idx
/
width
)
%
height
;
const
int
c
=
(
idx
/
width
/
height
)
%
channel
;
const
int
n
=
idx
/
width
/
height
/
channel
;
const
int
offset
=
n
*
6
;
if
(
c
>=
(
indices
[
offset
]
-
1
)
&&
c
<=
(
indices
[
offset
+
1
]
-
1
)
&&
h
>=
(
indices
[
offset
+
2
]
-
1
)
&&
h
<=
(
indices
[
offset
+
3
]
-
1
)
&&
w
>=
(
indices
[
offset
+
4
]
-
1
)
&&
w
<=
(
indices
[
offset
+
5
]
-
1
))
{
outputs
[
idx
]
=
inputs
[
idx
]
*
value
;
}
else
{
outputs
[
idx
]
=
inputs
[
idx
];
}
}
}
template
<
>
void
MulValue
<
DEVICE_TYPE_GPU
>
(
real
*
outputs
,
const
real
*
inputs
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
)
{
real
value
=
conf
.
get
<
real
>
(
"value"
);
int
number
=
shape
[
0
];
int
channel
=
shape
[
1
];
int
height
=
shape
[
2
];
int
width
=
shape
[
3
];
size_t
nth
=
number
*
channel
*
height
*
width
;
int
blockSize
=
1024
;
int
gridSize
=
(
nth
+
blockSize
-
1
)
/
blockSize
;
KeMulValue
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
outputs
,
inputs
,
indices
,
value
,
channel
,
height
,
width
,
nth
);
CHECK_SYNC
(
"MulValue"
);
}
__global__
void
KeMulValueDiff
(
const
real
*
inGrad
,
real
*
outGrad
,
const
real
*
indices
,
real
value
,
int
channel
,
int
height
,
int
width
,
int
nthreads
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
nthreads
)
{
const
int
w
=
idx
%
width
;
const
int
h
=
(
idx
/
width
)
%
height
;
const
int
c
=
(
idx
/
width
/
height
)
%
channel
;
const
int
n
=
idx
/
width
/
height
/
channel
;
const
int
offset
=
n
*
6
;
if
(
c
>=
(
indices
[
offset
]
-
1
)
&&
c
<=
(
indices
[
offset
+
1
]
-
1
)
&&
h
>=
(
indices
[
offset
+
2
]
-
1
)
&&
h
<=
(
indices
[
offset
+
3
]
-
1
)
&&
w
>=
(
indices
[
offset
+
4
]
-
1
)
&&
w
<=
(
indices
[
offset
+
5
]
-
1
))
{
outGrad
[
idx
]
+=
inGrad
[
idx
]
*
value
;
}
else
{
outGrad
[
idx
]
+=
inGrad
[
idx
];
}
}
}
template
<
>
void
MulValueGrad
<
DEVICE_TYPE_GPU
>
(
const
real
*
inGrad
,
real
*
outGrad
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
)
{
real
value
=
conf
.
get
<
real
>
(
"value"
);
int
number
=
shape
[
0
];
int
channel
=
shape
[
1
];
int
height
=
shape
[
2
];
int
width
=
shape
[
3
];
size_t
nth
=
number
*
channel
*
height
*
width
;
int
blockSize
=
1024
;
int
gridSize
=
(
nth
+
blockSize
-
1
)
/
blockSize
;
KeMulValueDiff
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
inGrad
,
outGrad
,
indices
,
value
,
channel
,
height
,
width
,
nth
);
CHECK_SYNC
(
"MulValueGrad"
);
}
}
// namespace paddle
paddle/function/MulValueOpTest.cpp
已删除
100644 → 0
浏览文件 @
07f3f07f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
namespace
paddle
{
TEST
(
MulValue
,
real
)
{
for
(
size_t
numSamples
:
{
5
,
32
})
{
for
(
size_t
channels
:
{
5
,
5
,
32
})
{
for
(
size_t
imgSizeH
:
{
5
,
33
,
100
})
{
for
(
size_t
imgSizeW
:
{
5
,
32
,
96
})
{
for
(
real
value
:
{
-
0.5
,
0.0
,
0.5
})
{
for
(
bool
firstHalf
:
{
false
,
true
})
{
VLOG
(
3
)
<<
" numSamples="
<<
numSamples
<<
" channels="
<<
channels
<<
" imgSizeH="
<<
imgSizeH
<<
" imgSizeW="
<<
imgSizeW
;
for
(
bool
test_grad
:
{
false
})
{
CpuGpuFuncCompare
compare
(
test_grad
?
"MulValueGrad"
:
"MulValue"
,
FuncConfig
().
set
<
real
>
(
"value"
,
value
));
TensorShape
shape
{
numSamples
,
channels
,
imgSizeH
,
imgSizeW
};
TensorShape
indicesShape
{
numSamples
,
6
};
compare
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
shape
));
compare
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
indicesShape
));
compare
.
registerInitCallback
([
=
](
BufferArg
&
arg
,
size_t
index
)
{
if
(
index
==
1
)
{
real
*
data
=
(
real
*
)
arg
.
data
();
for
(
size_t
i
=
0
;
i
<
numSamples
;
++
i
)
{
size_t
offset
=
i
*
6
;
data
[
offset
]
=
firstHalf
?
1
:
(
int
)
channels
/
2
;
data
[
offset
+
1
]
=
firstHalf
?
(
int
)
channels
/
2
:
channels
;
data
[
offset
+
2
]
=
firstHalf
?
1
:
(
int
)
imgSizeH
/
2
;
data
[
offset
+
3
]
=
firstHalf
?
(
int
)
imgSizeH
/
2
:
imgSizeH
;
data
[
offset
+
4
]
=
firstHalf
?
1
:
(
int
)
imgSizeW
/
2
;
data
[
offset
+
5
]
=
firstHalf
?
(
int
)
imgSizeW
/
2
:
imgSizeW
;
}
}
});
compare
.
addOutputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
shape
,
test_grad
?
ADD_TO
:
ASSIGN_TO
),
test_grad
?
ADD_TO
:
ASSIGN_TO
);
compare
.
run
();
}
}
}
}
}
}
}
}
}
// namespace paddle
paddle/function/ScaleSubRegionOpTest.cpp
浏览文件 @
04a35150
...
...
@@ -19,9 +19,9 @@ namespace paddle {
TEST
(
ScaleSubRegion
,
real
)
{
for
(
size_t
numSamples
:
{
5
,
32
})
{
for
(
size_t
channels
:
{
5
,
5
,
32
})
{
for
(
size_t
imgSizeH
:
{
5
,
33
,
100
})
{
for
(
size_t
imgSizeW
:
{
5
,
32
,
96
})
{
for
(
size_t
channels
:
{
5
,
32
})
{
for
(
size_t
imgSizeH
:
{
5
,
33
})
{
for
(
size_t
imgSizeW
:
{
5
,
32
})
{
for
(
real
value
:
{
-
0.5
,
0.0
,
0.5
})
{
for
(
bool
firstHalf
:
{
false
,
true
})
{
VLOG
(
3
)
<<
" numSamples="
<<
numSamples
...
...
paddle/gserver/layers/MulValueLayer.cpp
已删除
100644 → 0
浏览文件 @
07f3f07f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MulValueLayer.h"
#include "paddle/utils/Stat.h"
namespace
paddle
{
REGISTER_LAYER
(
mul_value
,
MulValueLayer
);
bool
MulValueLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
Layer
::
init
(
layerMap
,
parameterMap
);
CHECK_EQ
(
static_cast
<
int
>
(
inputLayers_
.
size
()),
2
);
auto
&
conf
=
config_
.
inputs
(
0
).
mul_value_conf
();
value_
=
conf
.
value
();
createFunction
(
forward_
,
"MulValue"
,
FuncConfig
().
set
(
"value"
,
value_
));
createFunction
(
backward_
,
"MulValueGrad"
,
FuncConfig
().
set
(
"value"
,
value_
));
return
true
;
}
void
MulValueLayer
::
forward
(
PassType
passType
)
{
Layer
::
forward
(
passType
);
auto
in0
=
getInput
(
0
);
imgH_
=
in0
.
getFrameHeight
();
imgW_
=
in0
.
getFrameWidth
();
if
(
imgH_
==
0
||
imgW_
==
0
)
{
auto
&
conf
=
config_
.
inputs
(
0
).
mul_value_conf
();
imgH_
=
conf
.
image_conf
().
img_size_y
();
imgW_
=
conf
.
image_conf
().
img_size
();
}
MatrixPtr
imgV
=
in0
.
value
;
size_t
batchSize
=
imgV
->
getHeight
();
size_t
spatialSize
=
imgH_
*
imgW_
;
channelsNum_
=
imgV
->
getWidth
()
/
spatialSize
;
shape_
=
TensorShape
({
batchSize
,
channelsNum_
,
imgH_
,
imgW_
});
resetOutput
(
batchSize
,
imgV
->
getWidth
());
MatrixPtr
indicesV
=
getInputValue
(
1
);
indicesShape_
=
TensorShape
({
batchSize
,
6
});
REGISTER_TIMER_INFO
(
"MulValueForward"
,
getName
().
c_str
());
BufferArgs
inArgs
;
BufferArgs
outArgs
;
inArgs
.
addArg
(
*
imgV
,
shape_
);
inArgs
.
addArg
(
*
indicesV
,
indicesShape_
);
MatrixPtr
outV
=
getOutputValue
();
outArgs
.
addArg
(
*
outV
,
shape_
,
ASSIGN_TO
);
forward_
[
0
]
->
calc
(
inArgs
,
outArgs
);
}
void
MulValueLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
REGISTER_TIMER_INFO
(
"MulValueBackward"
,
getName
().
c_str
());
BufferArgs
inArgs
;
BufferArgs
outArgs
;
inArgs
.
addArg
(
*
getOutputGrad
(),
shape_
);
inArgs
.
addArg
(
*
getInputValue
(
1
),
indicesShape_
);
outArgs
.
addArg
(
*
getInputGrad
(
0
),
shape_
,
ADD_TO
);
backward_
[
0
]
->
calc
(
inArgs
,
outArgs
);
}
}
// namespace paddle
paddle/gserver/layers/MulValueLayer.h
已删除
100644 → 0
浏览文件 @
07f3f07f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
namespace
paddle
{
/**
* \brief For each instance, this layer can be used to multiply a value to a
* specified sub continuous region. By providing start index and end
* index for C/H/W, you can specify the location and shape of the
* region.
*
* input_0: Input value.
* input_1: Indices value to specify the location an shape of the
* region.
*/
class
MulValueLayer
:
public
Layer
{
public:
explicit
MulValueLayer
(
const
LayerConfig
&
config
)
:
Layer
(
config
)
{}
~
MulValueLayer
()
{}
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
);
void
forward
(
PassType
passType
);
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
);
protected:
TensorShape
shape_
;
TensorShape
indicesShape_
;
size_t
imgH_
;
size_t
imgW_
;
size_t
channelsNum_
;
real
value_
;
};
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录