In PyTorch, torch.dot() does not broadcast: it only computes the dot product of two 1D tensors (vectors) with the same number of elements.
If you pass multi-dimensional tensors (such as matrices or batches) to torch.dot(), PyTorch raises a RuntimeError instead of broadcasting them.
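A minimal sketch of both cases. The exact error message differs between PyTorch versions, so the example only prints it instead of matching on it:

```python
import torch

# Two 1D tensors with the same number of elements: torch.dot works.
u = torch.tensor([1.0, 2.0, 3.0])
v = torch.tensor([4.0, 5.0, 6.0])
print(torch.dot(u, v))  # tensor(32.)  (1*4 + 2*5 + 3*6)

# 2D inputs are rejected instead of being broadcast.
M = torch.randn(2, 3)
N = torch.randn(3, 2)
try:
    torch.dot(M, N)
    dot_2d_failed = False
except RuntimeError as err:
    dot_2d_failed = True
    print("RuntimeError:", err)
```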
🛠️ The Solution: What to Use Instead
Depending on your use case, one of the following alternatives is a better fit. Only torch.matmul (or the @ operator) and element-wise multiplication broadcast; torch.mm and torch.bmm are intentionally strict:
1. Use torch.matmul or the @ Operator (Recommended)
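This is PyTorch's counterpart to NumPy's np.matmul (not np.dot, which treats N-D inputs differently). It broadcasts the leading batch dimensions; the last two dimensions must still be compatible for matrix multiplication.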
Best for: Standard matrix multiplication, vector-matrix products, and batched operations.
```python
import torch

# Batch of 10 matrices (10 x 3 x 4) and a single matrix (4 x 5)
A = torch.randn(10, 3, 4)
B = torch.randn(4, 5)

# B is automatically broadcast across A's batch dimension
result = torch.matmul(A, B)  # Shape: [10, 3, 5]

# OR, equivalently, using the @ operator
result = A @ B  # Shape: [10, 3, 5]
```
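A short sketch of the broadcasting rules in action, using arbitrary example shapes: the batch dimensions of both operands broadcast against each other, and a 1D operand is temporarily promoted to a matrix (row vector on the left, column vector on the right) and the extra dimension is dropped afterwards.

```python
import torch

# Batch dimensions broadcast: (2, 1) and (5,) combine to (2, 5);
# the trailing (3 x 4) @ (4 x 6) part is an ordinary matrix product.
X = torch.randn(2, 1, 3, 4)
Y = torch.randn(5, 4, 6)
print((X @ Y).shape)  # torch.Size([2, 5, 3, 6])

# Vector-matrix product: the 1D left operand acts as a row vector.
vec = torch.randn(4)
W = torch.randn(4, 5)
print((vec @ W).shape)  # torch.Size([5])

# Batched matrix-vector product: the 1D right operand acts as a
# column vector and is applied to every matrix in the batch.
batch = torch.randn(10, 3, 4)
print((batch @ vec).shape)  # torch.Size([10, 3])
```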
2. Use torch.mm
This multiplies exactly two 2D matrices. It does not broadcast.
Best for: Strict 2D matrix multiplication where you want an error if dimensions don't align perfectly.
```python
import torch

A = torch.randn(3, 4)
B = torch.randn(4, 5)
result = torch.mm(A, B)  # Shape: [3, 5]
```
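To illustrate the "error if dimensions don't align" behavior, here is a small sketch showing torch.mm rejecting both a batched input and mismatched inner dimensions (messages are printed, not matched, since their wording can change between versions):

```python
import torch

A = torch.randn(10, 3, 4)  # a batch of matrices
B = torch.randn(4, 5)

# torch.mm refuses 3D input instead of broadcasting B over the batch.
try:
    torch.mm(A, B)
    mm_failed = False
except RuntimeError as err:
    mm_failed = True
    print("RuntimeError:", err)

# Mismatched inner dimensions are rejected too: (3 x 4) @ (3 x 5).
try:
    torch.mm(torch.randn(3, 4), torch.randn(3, 5))
    inner_failed = False
except RuntimeError as err:
    inner_failed = True
    print("RuntimeError:", err)
```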
3. Use torch.bmm
This performs batch matrix multiplication. Both tensors must be 3D, and their batch sizes must match exactly. It does not broadcast.
Best for: Explicitly controlled batch matrix multiplications.
```python
import torch

A = torch.randn(10, 3, 4)
B = torch.randn(10, 4, 5)
result = torch.bmm(A, B)  # Shape: [10, 3, 5]
```
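Because torch.bmm does not broadcast, reproducing the torch.matmul example from the first section requires expanding the single matrix to the batch size yourself. A sketch, assuming the same shapes as that example:

```python
import torch

torch.manual_seed(0)
A = torch.randn(10, 3, 4)
B = torch.randn(4, 5)

# torch.bmm needs two 3D tensors with equal batch sizes, so B is
# expanded explicitly. expand() returns a view that repeats B along
# a new leading dimension without copying its data.
B_batched = B.unsqueeze(0).expand(10, -1, -1)  # Shape: [10, 4, 5]
result_bmm = torch.bmm(A, B_batched)

# torch.matmul performs the same broadcast implicitly.
result_matmul = torch.matmul(A, B)

print(result_bmm.shape)  # torch.Size([10, 3, 5])
print(torch.allclose(result_bmm, result_matmul, atol=1e-5))  # True
```

The explicit expand makes the batch size visible in your code, which is the point of choosing torch.bmm over torch.matmul.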
4. Use Element-wise Multiplication * with .sum()
If you want dot-product behavior (multiply matching elements, then sum them) along one dimension of tensors that broadcast against each other, combine the * operator with .sum().
Best for: Custom element-wise operations before reducing.
```python
import torch

A = torch.randn(10, 3)
B = torch.randn(1, 3)  # Broadcasts along the batch dimension (1 -> 10)

# Multiply element-wise (with broadcasting), then sum over the last dimension
result = (A * B).sum(dim=-1)  # Shape: [10]
```
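The same row-wise dot product can also be written with torch.einsum or with matmul on unsqueezed tensors. A sketch comparing the three forms, using two same-shaped (10, 3) tensors as an arbitrary example:

```python
import torch

torch.manual_seed(0)
A = torch.randn(10, 3)
B = torch.randn(10, 3)

# Row-wise dot products: result[i] = sum over j of A[i, j] * B[i, j]
via_sum = (A * B).sum(dim=-1)

# torch.einsum: "bi,bi->b" multiplies matching entries and sums over
# the shared index i, keeping the batch index b.
via_einsum = torch.einsum("bi,bi->b", A, B)

# Batched matmul: (10, 1, 3) @ (10, 3, 1) -> (10, 1, 1), then drop
# the two size-1 dimensions.
via_matmul = (A.unsqueeze(-2) @ B.unsqueeze(-1)).squeeze(-1).squeeze(-1)

print(via_sum.shape)  # torch.Size([10])
print(torch.allclose(via_sum, via_einsum, atol=1e-5))
print(torch.allclose(via_sum, via_matmul, atol=1e-5))
```

The * plus .sum() form is usually the easiest to read; einsum becomes more attractive when several dimensions have to be contracted or reordered at once.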
📊 Quick Comparison Summary
| Function / Operator | Input Dimensions Allowed | Supports Broadcasting? | Primary Use Case |
|---|---|---|---|
| torch.dot | Strictly 1D and 1D | ❌ No | Basic vector-vector dot product |
| torch.mm | Strictly 2D and 2D | ❌ No | Standard 2D matrix multiplication |
| torch.bmm | Strictly 3D and 3D | ❌ No | Strict batch matrix multiplication |
| torch.matmul / @ | 1D or higher (mixed dimensions allowed) | ✅ Yes (batch dimensions) | Flexible multiplication with batch broadcasting |