Saturday, May 30, 2026

PyTorch torch.dot() does not broadcast

In PyTorch, torch.dot() does not broadcast: it is designed strictly to compute the dot product of two 1D tensors (vectors) with the same number of elements.

If you pass multi-dimensional tensors (like matrices or batches) to torch.dot(), PyTorch raises a RuntimeError.
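As a minimal sketch of that behavior (the tensor values and variable names here are illustrative, not from any particular codebase), the snippet below calls torch.dot() once with two 1D tensors and once with a 2D tensor, catching the resulting error:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

# Two 1D tensors of equal length: the result is a 0-dimensional (scalar) tensor.
scalar = torch.dot(a, b)  # 1*4 + 2*5 + 3*6 = 32

# A 2D tensor is rejected outright instead of being broadcast.
M = torch.randn(2, 3)
try:
    torch.dot(M, b)
    raised = False
except RuntimeError as err:
    raised = True
    print("torch.dot rejected 2D input:", err)
```

The exact wording of the error message may differ between PyTorch versions, so it is safer to catch the exception type than to match the text.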

🛠️ The Solution: What to Use Instead

Depending on your use case, use one of the following alternatives. Only torch.matmul (or the @ operator) and element-wise multiplication broadcast automatically; torch.mm and torch.bmm are deliberately strict:

1. Use torch.matmul or the @ Operator (Recommended)

This is the PyTorch counterpart of NumPy's np.matmul, and it matches np.dot for 1D and 2D inputs. It broadcasts across batch dimensions, that is, every dimension except the last two.

Best for: Standard matrix multiplication, vector-matrix products, and batched operations.

python
import torch

# Batch of 10 matrices (10 x 3 x 4) and a matrix (4 x 5)
A = torch.randn(10, 3, 4)
B = torch.randn(4, 5)

# B is automatically broadcast to match A's batch size
result = torch.matmul(A, B)  # Shape: [10, 3, 5]
# OR using the operator
result = A @ B               # Shape: [10, 3, 5]
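Broadcasting in torch.matmul is not limited to a missing batch dimension. The sketch below, with shapes chosen only for illustration, shows two sets of batch dimensions being broadcast against each other, along with the 1D cases where @ behaves like a dot product or a matrix-vector product:

```python
import torch

# Batch dimensions (2, 1) and (5,) broadcast to (2, 5);
# the last two dimensions are multiplied as matrices: (3, 4) @ (4, 6).
A = torch.randn(2, 1, 3, 4)
B = torch.randn(5, 4, 6)
batched = A @ B                        # shape (2, 5, 3, 6)

# Two 1D tensors: @ returns a scalar, the same value as torch.dot.
u = torch.tensor([1.0, 2.0, 3.0, 4.0])
v = torch.tensor([1.0, 0.0, -1.0, 2.0])
inner = u @ v                          # 1 + 0 - 3 + 8 = 6

# A matrix times a vector drops the contracted dimension.
M = torch.randn(3, 4)
mat_vec = M @ v                        # shape (3,)
```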

2. Use torch.mm

This multiplies exactly two 2D matrices. It does not broadcast.

Best for: Strict 2D matrix multiplication where you want an error if dimensions don't align perfectly.

python
A = torch.randn(3, 4)
B = torch.randn(4, 5)
result = torch.mm(A, B)  # Shape: [3, 5]
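To illustrate the strictness, here is a small sketch (shapes are illustrative) in which torch.mm refuses a batched input that torch.matmul would accept:

```python
import torch

A = torch.randn(10, 3, 4)   # batched (3D) input
B = torch.randn(4, 5)

# torch.mm only accepts 2D tensors, so the 3D input raises an error.
try:
    torch.mm(A, B)
    mm_raised = False
except RuntimeError as err:
    mm_raised = True
    print("torch.mm rejected 3D input:", err)

# torch.matmul handles the same pair by broadcasting B over the batch.
out = torch.matmul(A, B)    # shape (10, 3, 5)
```

This makes torch.mm useful as a guard: an unexpected extra dimension fails loudly instead of silently producing a batched result.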

3. Use torch.bmm

This performs batch matrix multiplication. Both tensors must be 3D, and their batch sizes must match exactly. It does not broadcast.

Best for: Explicitly controlled batch matrix multiplications.

python
A = torch.randn(10, 3, 4)
B = torch.randn(10, 4, 5)
result = torch.bmm(A, B)  # Shape: [10, 3, 5]
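The following sketch, with illustrative shapes, shows that torch.bmm does not stretch a batch dimension of size 1, and one possible workaround: calling expand on the smaller tensor first. Using expand is a suggestion here, not the only option; torch.matmul would do the broadcasting for you.

```python
import torch

A = torch.randn(10, 3, 4)
B = torch.randn(1, 4, 5)    # batch size 1 is not broadcast by torch.bmm

try:
    torch.bmm(A, B)
    bmm_raised = False
except RuntimeError as err:
    bmm_raised = True
    print("torch.bmm rejected mismatched batch sizes:", err)

# expand returns a view with the batch dimension repeated, without copying data.
B_expanded = B.expand(10, -1, -1)   # shape (10, 4, 5)
out = torch.bmm(A, B_expanded)      # shape (10, 3, 5)
```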

4. Use Element-wise Multiplication * with .sum()

If you want traditional dot-product behavior (multiply matching elements and sum them) over a specific dimension of broadcast tensors, combine the * operator with .sum().

Best for: Custom element-wise operations before reducing.

python
A = torch.randn(10, 3)
B = torch.randn(1, 3)  # Broadcasts along the batch dimension (1 -> 10)

# Multiply element-wise (broadcasts) and sum over the last dimension
result = (A * B).sum(dim=-1)  # Shape: [10]
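As a quick sanity check, the sketch below uses small hand-picked values (chosen only for illustration) to confirm that the multiply-and-sum pattern gives the same numbers as calling torch.dot on each row separately:

```python
import torch

A = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])   # shape (2, 3)
w = torch.tensor([1.0, 0.0, -1.0])    # shape (3,), broadcasts against each row

# Broadcasted element-wise product, then reduce over the last dimension.
row_dots = (A * w).sum(dim=-1)        # shape (2,)

# Reference: one torch.dot call per row.
reference = torch.stack([torch.dot(row, w) for row in A])

# For this particular case a matrix-vector product gives the same result.
via_matmul = A @ w

print(row_dots)  # tensor([-2., -2.])
```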

📊 Quick Comparison Summary

| Function / Operator | Input Dimensions Allowed | Supports Broadcasting? | Primary Use Case |
|---|---|---|---|
| torch.dot | Strictly 1D and 1D | ❌ No | Basic vector-vector dot product |
| torch.mm | Strictly 2D and 2D | ❌ No | Standard 2D matrix multiplication |
| torch.bmm | Strictly 3D and 3D | ❌ No | Strict batch matrix multiplication |
| torch.matmul / @ | 1D or higher, in any combination | ✅ Yes (batch dimensions) | Flexible, broadcast-safe multiplication |
