diff --git a/spec/API_specification/array_api/linalg.py b/spec/API_specification/array_api/linalg.py
index 5336d93c6..43cf983e8 100644
--- a/spec/API_specification/array_api/linalg.py
+++ b/spec/API_specification/array_api/linalg.py
@@ -1,4 +1,4 @@
-from ._types import Literal, Optional, Tuple, Union, Sequence, array
+from ._types import Literal, Optional, Tuple, Union, Sequence, array, dtype
 from .constants import inf
 
 def cholesky(x: array, /, *, upper: bool = False) -> array:
@@ -437,10 +437,20 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
     Alias for :func:`~array_api.tensordot`.
     """
 
-def trace(x: array, /, *, offset: int = 0) -> array:
+def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> array:
     """
     Returns the sum along the specified diagonals of a matrix (or a stack of matrices) ``x``.
 
+    **Special Cases**
+
+    Let ``N`` equal the number of elements over which to compute the sum.
+
+    -   If ``N`` is ``0``, the sum is ``0`` (i.e., the empty sum).
+
+    For floating-point operands,
+
+    -   If ``x_i`` is ``NaN``, the sum is ``NaN`` (i.e., ``NaN`` values propagate).
+
     Parameters
     ----------
     x: array
@@ -453,6 +463,18 @@ def trace(x: array, /, *, offset: int = 0) -> array:
         -   ``offset < 0``: off-diagonal below the main diagonal.
 
         Default: ``0``.
+    dtype: Optional[dtype]
+        data type of the returned array. If ``None``,
+
+        -   if the default data type corresponding to the data type "kind" (integer or floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
+        -   if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
+        -   if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
+        -   if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
+
+        If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``.
+
+        .. note::
+           keyword argument is intended to help prevent data type overflows.
 
     Returns
     -------
@@ -463,7 +485,7 @@ def trace(x: array, /, *, offset: int = 0) -> array:
 
           out[i, j, k, ..., l] = trace(a[i, j, k, ..., l, :, :])
 
-        The returned array must have the same data type as ``x``.
+        The returned array must have a data type as described by the ``dtype`` parameter above.
     """
 
 def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array: