Comparing HPU-Optimized safe_softmax with Native PyTorch safe_softmax
This article demonstrates how to benchmark and compare the performance of the Habana Processing Unit (HPU)-optimized safe_softmax operation against the native PyTorch implementation. The provided Python script guides you through the process step by step, with detailed explanations for each part. Additionally, we provide some context about safe_softmax, its purpose, and its use cases.
Important Note: No Special Setup Required
The safe_softmax operation works out of the box in PyTorch. When running your code on Habana hardware, the HPU-optimized implementation is automatically used without any additional configuration. This seamless integration allows you to benefit from performance improvements without modifying your existing code.
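As a minimal sketch (the tensor shape and values below are arbitrary and purely illustrative), the only HPU-specific step is moving the tensor to the hpu device; the operator call itself is the same one used later in this article:
import torch
import habana_frameworks.torch as ht  # registers the "hpu" device with PyTorch

# Illustrative tensor, moved to the HPU with a plain .to("hpu") call
x = torch.randn(4, 8).to("hpu")

# No extra configuration: calling the operator on an HPU tensor dispatches
# the HPU-optimized kernel automatically
probs = torch.ops.aten._safe_softmax.default(x, dim=-1)
ht.hpu.synchronize()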
What is safe_softmax?
The softmax function is a common operation in machine learning, particularly in classification tasks. It converts raw logits into probabilities by applying the exponential function and normalizing the results. However, the standard softmax can encounter numerical instability when dealing with very large or very small values in the input tensor, leading to overflow or underflow issues.
To address this, safe_softmax stabilizes the computation by subtracting the maximum value in each row (or along the specified dimension) from the logits before applying the exponential function. This ensures that the largest exponent is zero, preventing overflow.
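To illustrate the idea (a plain-PyTorch sketch of the max-subtraction trick, not the actual safe_softmax implementation), the snippet below contrasts a naive softmax, which overflows for large logits, with the stabilized version described above:
import torch

logits = torch.tensor([1000.0, 1001.0, 1002.0])

# Naive softmax: exp(1000) overflows to inf in float32, so the result is all NaN
naive = torch.exp(logits) / torch.exp(logits).sum()
print(naive)  # tensor([nan, nan, nan])

# Stabilized softmax: subtract the row maximum first, so the largest exponent is exp(0) = 1
shifted = logits - logits.max()
stable = torch.exp(shifted) / torch.exp(shifted).sum()
print(stable)  # tensor([0.0900, 0.2447, 0.6652])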
Why is safe_softmax important?
- Numerical Stability: Prevents overflow/underflow issues during computation.
- Widely Used: Commonly used in neural networks, especially in the final layer for classification tasks.
- Efficiency: Optimized implementations can significantly improve performance on specialized hardware like GPUs or HPUs.
Step-by-Step Explanation of the Code
1. Importing Required Libraries
import torch
import timeit
import habana_frameworks.torch as ht
from torch._decomp.decompositions import safe_softmax as native_safe_softmax
- torch: The core PyTorch library for tensor operations.
- timeit: A Python module for measuring execution time.
- habana_frameworks.torch: Provides support for Habana hardware (HPUs).
- safe_softmax (imported as native_safe_softmax): The native PyTorch implementation of safe_softmax, imported for comparison.
2. Defining the HPU-Optimized safe_softmax
hpu_safe_softmax = torch.ops.aten._safe_softmax.default
- The HPU-optimized version of safe_softmax is accessed via the torch.ops.aten namespace. This implementation is specifically designed to leverage Habana hardware for faster execution.
3. Preparing the Input Tensor
input_tensor = torch.tensor([[1.0, 2.0, float("-inf")], [3.0, 4.0, 5.0]]).to("hpu")
- A 2D tensor is created with some typical values, including -inf to simulate edge cases.
- The tensor is moved to the HPU device using .to("hpu"). (An optional sanity check on this tensor is sketched below.)
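Although not part of the original script, a quick sanity check (assuming both implementations should agree numerically on this input) can confirm that the two produce the same result before timing them:
# Optional sanity check, not in the original script: compare both outputs
out_hpu = hpu_safe_softmax(input_tensor, dim=1)
out_native = native_safe_softmax(input_tensor, dim=1)
print(torch.allclose(out_hpu.cpu(), out_native.cpu()))  # expected: True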
4. Warmup for Fair Benchmarking
hpu_safe_softmax(input_tensor, dim=1); ht.hpu.synchronize()
native_safe_softmax(input_tensor, dim=1); ht.hpu.synchronize()
- Both the HPU-optimized and native implementations are executed once before benchmarking. This ensures that any initialization overhead is excluded from the timing measurements.
- ht.hpu.synchronize() ensures that all HPU operations are completed before proceeding.
5. Benchmarking the Implementations
num_iterations = 10000
hpu_time = timeit.timeit(
"hpu_safe_softmax(input_tensor, dim=1); ht.hpu.synchronize()",
globals=globals(),
number=num_iterations
)
native_time = timeit.timeit(
"native_safe_softmax(input_tensor, dim=1); ht.hpu.synchronize()",
globals=globals(),
number=num_iterations
)
- The timeit module is used to measure the execution time of each implementation over 10,000 iterations.
- The globals=globals() argument allows the timeit module to access the variables and functions defined in the script.
6. Printing the Results
print(f"Performance comparison over {num_iterations} iterations:")
print(f"Native safe_softmax: {native_time:.6f} seconds")
print(f"HPU safe_softmax: {hpu_time:.6f} seconds")
- The execution times for both implementations are printed, allowing for a direct comparison of their performance.
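If you also want derived metrics, a small addition (not part of the original script) can compute the speedup and the average time per call from the measured totals:
# Optional, not in the original script: derived metrics from the measured totals
print(f"Speedup: {native_time / hpu_time:.2f}x")
print(f"Average HPU time per call: {hpu_time / num_iterations * 1e6:.2f} microseconds")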
Example Output
After running the script, you might see output similar to the following (lower is better):
Performance comparison over 10000 iterations:
Native safe_softmax: 1.004057 seconds
HPU safe_softmax: 0.104004 seconds