diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 2906d0acd9..33610c65f7 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -838,6 +838,24 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
at::Half** C,
CUDAContext* context,
TensorProto::DataType math_type) {
+#if defined(USE_ROCM)
+ // loop over matrices in the batch
+ for (int i = 0; i < batch_size; ++i) {
+ Gemm<at::Half, CUDAContext>(
+ trans_A,
+ trans_B,
+ M,
+ N,
+ K,
+ alpha,
+ A[i],
+ B[i],
+ beta,
+ C[i],
+ context,
+ math_type);
+ }
+#else
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
const int lda = (trans_A == CblasNoTrans) ? K : M;
@@ -912,6 +930,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
} else {
CAFFE_THROW("Unsupported math type");
}
+#endif
}
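For reference, the `USE_ROCM` branch above emulates the batched call by issuing one plain `Gemm` per matrix in the batch, using the per-matrix pointer arrays. Below is a minimal standalone sketch of the same decomposition, assuming a naive row-major CPU float GEMM with no transpose handling; the function names mirror the diff but are otherwise illustrative, not Caffe2's actual signatures.

```cpp
#include <cstdio>
#include <vector>

// Naive single-matrix GEMM: C = alpha * A(MxK) * B(KxN) + beta * C(MxN),
// row-major, no transposes. Stand-in for the real Gemm<at::Half, CUDAContext>.
void Gemm(int M, int N, int K, float alpha,
          const float* A, const float* B, float beta, float* C) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}

// Batched GEMM emulated the same way as the ROCm branch in the diff:
// loop over the batch and issue one plain Gemm per matrix.
void GemmBatched(int batch_size, int M, int N, int K, float alpha,
                 const float** A, const float** B, float beta, float** C) {
  for (int i = 0; i < batch_size; ++i) {
    Gemm(M, N, K, alpha, A[i], B[i], beta, C[i]);
  }
}

int main() {
  const int batch = 2, M = 2, N = 2, K = 2;
  // Batch element 0: a generic 2x2 product; element 1: identity * B2.
  std::vector<float> a0 = {1, 2, 3, 4}, b0 = {5, 6, 7, 8}, c0(M * N, 0.f);
  std::vector<float> a1 = {1, 0, 0, 1}, b1 = {9, 8, 7, 6}, c1(M * N, 0.f);
  const float* As[batch] = {a0.data(), a1.data()};
  const float* Bs[batch] = {b0.data(), b1.data()};
  float* Cs[batch] = {c0.data(), c1.data()};
  GemmBatched(batch, M, N, K, 1.f, As, Bs, 0.f, Cs);
  std::printf("C[0] = [%g %g; %g %g]\n", c0[0], c0[1], c0[2], c0[3]);
  std::printf("C[1] = [%g %g; %g %g]\n", c1[0], c1[1], c1[2], c1[3]);
  return 0;
}
```

The per-matrix loop trades the single fused batched launch for one call per matrix, which presumably favors correctness over throughput on ROCm where the cuBLAS batched fp16 paths used in the `#else` branch are not available.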