summarylogtreecommitdiffstats
path: root/kernelmatrix_kernel.cu.patch
blob: 368344ab9995adce93253a6d44b78bc7eca0cbdf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
--- test.cpp	2020-12-11 11:33:51.985362605 +0800
+++ kernelmatrix_kernel.cu	2020-12-11 11:32:40.475793777 +0800
@@ -144,6 +144,40 @@ namespace svm_kernel {
         }
         kernel_type one(1);
         kernel_type zero(0);
+
+#if (CUDART_VERSION >= 11000)
+
+        cusparseSpMatDescr_t matA;
+        cusparseDnMatDescr_t matB, matC;
+#ifdef USE_DOUBLE
+        cudaDataType data_type = CUDA_R_64F;
+#else//kernel type is float
+        cudaDataType data_type = CUDA_R_32F;
+#endif
+        cusparseCreateCsr(&matA, m, k, nnz, (void*)csr_row_ptr.device_data(), (void*)csr_col_ind.device_data(),
+                          (void*)csr_val.device_data(), CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
+                          CUSPARSE_INDEX_BASE_ZERO, data_type);
+        cusparseCreateDnMat(&matB, n, k, n, (void*)dense_mat.device_data(), data_type, CUSPARSE_ORDER_COL);
+        cusparseCreateDnMat(&matC, m, n, m, (void*)result.device_data(), data_type, CUSPARSE_ORDER_COL);
+
+        size_t buffer_size = 0;
+        cusparseSpMM_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_TRANSPOSE,
+                                &one, matA, matB, &zero, matC, data_type, CUSPARSE_CSRMM_ALG1,
+                                &buffer_size);
+
+        void *p_buffer = nullptr;
+        cudaMalloc((void**)&p_buffer, buffer_size);
+
+        cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_TRANSPOSE,
+                    &one, matA, matB, &zero, matC, data_type, CUSPARSE_CSRMM_ALG1, p_buffer);
+
+        cudaFree(p_buffer);
+        cusparseDestroySpMat(matA);
+        cusparseDestroyDnMat(matB);
+        cusparseDestroyDnMat(matC);
+
+#else
+
 #ifdef USE_DOUBLE
         cusparseDcsrmm2(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_TRANSPOSE,
                         m, n, k, nnz, &one, descr, csr_val.device_data(), csr_row_ptr.device_data(),
@@ -154,9 +188,10 @@ namespace svm_kernel {
                         m, n, k, nnz, &one, descr, csr_val.device_data(), csr_row_ptr.device_data(),
                         csr_col_ind.device_data(),
                         dense_mat.device_data(), n, &zero, result.device_data(), m);
-#endif
-

         //cusparseScsrmm return row-major matrix, so no transpose is needed
+#endif // ifdef USE_DOUBLE
+
+#endif // if CUDART_VERSION >= 11000
     }
 }