未验证 提交 faba4b11 编写于 作者: B Bai Yifan 提交者: GitHub

Remove disable flag in test_fsp_op.py (#22171)

* fix fsp_op, test=develop

* fix fsp grad op maker, test=develop

* update op_use_default_grad_op_maker.spec, test=develop
上级 67e9247f
cos_sim
fsp
gru
match_matrix_tensor
maxout
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fsp_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -114,14 +115,37 @@ class FSPOpGrad : public framework::OperatorWithKernel {
}
};
template <typename T>
class FSPGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
op->SetType("fsp_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fsp, ops::FSPOp, ops::FSPOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(fsp, ops::FSPOp, ops::FSPOpMaker,
ops::FSPGradOpMaker<paddle::framework::OpDesc>,
ops::FSPGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fsp_grad, ops::FSPOpGrad);
REGISTER_OP_CPU_KERNEL(
fsp, ops::FSPOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -46,6 +46,7 @@ class FSPOpKernel : public framework::OpKernel<T> {
x_mat_desc.width_ = height * width;
x_mat_desc.batch_size_ = batch_size;
x_mat_desc.stride_ = x_channel * height * width;
x_mat_desc.trans_ = false;
math::MatDescriptor y_mat_desc;
y_mat_desc.height_ = height * width;
......@@ -93,12 +94,14 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
d_out_mat_desc.width_ = y_channel;
d_out_mat_desc.batch_size_ = batch_size;
d_out_mat_desc.stride_ = x_channel * y_channel;
d_out_mat_desc.trans_ = false;
math::MatDescriptor y_mat_desc;
y_mat_desc.height_ = y_channel;
y_mat_desc.width_ = h * w;
y_mat_desc.batch_size_ = batch_size;
y_mat_desc.stride_ = y_channel * h * w;
y_mat_desc.trans_ = false;
blas.MatMul(*d_out, d_out_mat_desc, *y, y_mat_desc,
static_cast<T>(1.0 / (h * w)), d_x, static_cast<T>(0.0));
......@@ -125,6 +128,7 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
x_mat_desc.width_ = h * w;
x_mat_desc.batch_size_ = batch_size;
x_mat_desc.stride_ = x_channel * h * w;
x_mat_desc.trans_ = false;
blas.MatMul(*d_out, d_out_mat_desc, *x, x_mat_desc,
static_cast<T>(1.0 / (h * w)), d_y, static_cast<T>(0.0));
......
......@@ -34,7 +34,6 @@ def fsp_matrix(a, b):
return np.mean(a_r * b_r, axis=1)
@unittest.skip("Disable temporarily.")
class TestFSPOp(OpTest):
def setUp(self):
self.op_type = "fsp"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册