summarylogtreecommitdiffstats
path: root/rocblas-batched.patch
blob: 1eef5458bc443efe7fc92a667aa73dcd1527ef3e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 2906d0acd9..33610c65f7 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -838,6 +838,24 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
     at::Half** C,
     CUDAContext* context,
     TensorProto::DataType math_type) {
+#if defined(USE_ROCM)
+  // loop over matrices in the batch
+  for (int i = 0; i < batch_size; ++i) {
+   Gemm<at::Half, CUDAContext>(
+        trans_A,
+        trans_B,
+        M,
+        N,
+        K,
+        alpha,
+        A[i],
+        B[i],
+        beta,
+        C[i],
+        context,
+        math_type);
+  }
+#else
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   const int lda = (trans_A == CblasNoTrans) ? K : M;
@@ -912,6 +930,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
   } else {
     CAFFE_THROW("Unsupported math type");
   }
+#endif
 }