tf.keras.models.clone_model
Clone a Functional or Sequential Model instance.
tf.keras.models.clone_model(
    model,
    input_tensors=None,
    clone_function=None,
    call_function=None,
    recursive=False,
    **kwargs
)
Used in the notebooks:
- Scalable model compression (https://www.tensorflow.org/tutorials/optimization/compression)
- Federated Learning for Text Generation (https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation)
Model cloning is similar to calling a model on new inputs,
except that it creates new layers (and thus new weights) instead
of sharing the weights of the existing layers.
Note that clone_model will not preserve the uniqueness of shared objects within the model (e.g. a single variable attached to two distinct layers will be restored as two separate variables).
Args:
  model: Instance of Model (could be a Functional model or a Sequential model).
  input_tensors: Optional list of input tensors or InputLayer objects to build the model upon. If not provided, new Input objects will be created.
  clone_function: Callable with signature fn(layer) used to clone each layer in the target model (except Input instances). It takes as argument the layer instance to be cloned and returns the corresponding layer instance to be used in the model copy. If unspecified, this callable defaults to the following serialization/deserialization function: lambda layer: layer.__class__.from_config(layer.get_config()). By passing a custom callable, you can customize your copy of the model, e.g. by wrapping certain layers of interest (you might want to replace all LSTM instances with equivalent Bidirectional(LSTM(...)) instances, for example). Defaults to None.
  call_function: Callable with signature fn(layer, *args, **kwargs) used to call each cloned layer with a set of inputs. It takes the layer instance, the call arguments and keyword arguments, and returns the call outputs. If unspecified, this callable defaults to the regular __call__() method: def fn(layer, *args, **kwargs): return layer(*args, **kwargs). By passing a custom callable, you can insert new layers before or after a given layer. Note: this argument can only be used with Functional models.
  recursive: Boolean. Whether to recursively clone any Sequential or Functional models encountered in the original Sequential/Functional model. If False, inner models are cloned by calling clone_function(). If True, inner models are cloned by calling clone_model() with the same clone_function, call_function, and recursive arguments (see the sketch after this list). Note that in this case, call_function will not be propagated to any Sequential model (since it is not applicable to Sequential models).
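For the recursive argument, here is a minimal sketch (not from the original page; the nested model and variable names are illustrative) of cloning a Functional model that contains an inner Sequential model:

from tensorflow import keras
from tensorflow.keras.models import clone_model

# Illustrative nested model: a Functional model wrapping a Sequential sub-model.
inner = keras.Sequential([
    keras.layers.Input(shape=(16,)),
    keras.layers.Dense(8, activation="relu"),
])
inputs = keras.Input(shape=(16,))
outputs = keras.layers.Dense(1)(inner(inputs))
outer = keras.Model(inputs, outputs)

# With recursive=True, the inner Sequential model is itself cloned via
# clone_model() (with the same clone_function/call_function/recursive
# arguments) instead of being handed to clone_function as a single layer.
clone = clone_model(outer, recursive=True)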
Returns:
  An instance of Model reproducing the behavior of the original model, on top of new input tensors, using newly instantiated weights. The cloned model may behave differently from the original model if a custom clone_function or call_function modifies a layer or layer call.
Example:
from tensorflow import keras
from tensorflow.keras.models import clone_model

# Create a test Sequential model.
model = keras.Sequential([
    keras.layers.Input(shape=(728,)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid'),
])
# Create a copy of the test model (with freshly initialized weights).
new_model = clone_model(model)
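Continuing the snippet above, a quick check (a hedged addition, not part of the original example) that the clone holds its own, freshly instantiated variables:

# None of the clone's weight variables are the original model's variables;
# their values come from a fresh random initialization.
assert all(w_new is not w_old
           for w_new, w_old in zip(new_model.weights, model.weights))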
Using a clone_function to make a model deterministic by setting the random seed everywhere:
def clone_function(layer):
    config = layer.get_config()
    if "seed" in config:
        config["seed"] = 1337
    return layer.__class__.from_config(config)

new_model = clone_model(model, clone_function=clone_function)
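As another illustration of clone_function (a hedged sketch, not from the original page; it assumes the model being cloned contains LSTM layers), the replacement of every LSTM with an equivalent Bidirectional(LSTM(...)) mentioned in the argument description could look like:

def replace_lstm(layer):
    # Wrap each LSTM in a Bidirectional layer; clone everything else via the
    # default config round-trip.
    if isinstance(layer, keras.layers.LSTM):
        return keras.layers.Bidirectional(
            keras.layers.LSTM.from_config(layer.get_config())
        )
    return layer.__class__.from_config(layer.get_config())

new_model = clone_model(model, clone_function=replace_lstm)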
Using a call_function to add a Dropout layer after each Dense layer (without recreating new layers):
def call_function(layer, *args, **kwargs):
    out = layer(*args, **kwargs)
    if isinstance(layer, keras.layers.Dense):
        out = keras.layers.Dropout(0.5)(out)
    return out

new_model = clone_model(
    model,
    clone_function=lambda x: x,  # Reuse the same layers.
    call_function=call_function,
)
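Because clone_function=lambda x: x reuses the original layer instances, the clone in this case shares the original model's weights rather than creating new ones; the inserted Dropout layers add no weights of their own. A quick hedged check (not part of the original example):

# Every weight variable of the original model is the very same object in the
# clone, since the layers themselves were reused.
assert all(any(w is v for v in new_model.weights) for w in model.weights)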
Note that subclassed models cannot be cloned by default, since their internal layer structure is not known. To achieve functionality equivalent to clone_model for a subclassed model, make sure that the model class implements get_config() (and optionally from_config()), and call:
new_model = model.__class__.from_config(model.get_config())
In the case of a subclassed model, you cannot use a custom clone_function.
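A minimal sketch of that approach (the class, layer sizes, and names below are illustrative, not from the original page):

from tensorflow import keras

class MyModel(keras.Model):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.dense = keras.layers.Dense(units, activation="relu")
        self.out = keras.layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
        return self.out(self.dense(inputs))

    def get_config(self):
        # Everything needed to re-create the model.
        return {"units": self.units}

model = MyModel()
# The clone_model() equivalent for a subclassed model: rebuild from config,
# which yields a new instance with freshly initialized weights.
new_model = MyModel.from_config(model.get_config())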
[[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2024-06-07 UTC."],[],[],null,["# tf.keras.models.clone_model\n\n\u003cbr /\u003e\n\n|---------------------------------------------------------------------------------------------------------------|\n| [View source on GitHub](https://github.com/keras-team/keras/tree/v3.3.3/keras/src/models/cloning.py#L13-L209) |\n\nClone a Functional or Sequential `Model` instance. \n\n tf.keras.models.clone_model(\n model,\n input_tensors=None,\n clone_function=None,\n call_function=None,\n recursive=False,\n **kwargs\n )\n\n### Used in the notebooks\n\n| Used in the tutorials |\n|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| - [Scalable model compression](https://www.tensorflow.org/tutorials/optimization/compression) - [Federated Learning for Text Generation](https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation) |\n\nModel cloning is similar to calling a model on new inputs,\nexcept that it creates new layers (and thus new weights) instead\nof sharing the weights of the existing layers.\n\nNote that\n`clone_model` will not preserve the uniqueness of shared objects within the\nmodel (e.g. a single variable attached to two distinct layers will be\nrestored as two separate variables).\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Args ---- ||\n|------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| `model` | Instance of `Model` (could be a Functional model or a Sequential model). |\n| `input_tensors` | optional list of input tensors or InputLayer objects to build the model upon. If not provided, new `Input` objects will be created. |\n| `clone_function` | Callable with signature `fn(layer)` to be used to clone each layer in the target model (except `Input` instances). It takes as argument the layer instance to be cloned, and returns the corresponding layer instance to be used in the model copy. If unspecified, this callable defaults to the following serialization/deserialization function: `lambda layer: layer.__class__.from_config(layer.get_config())`. By passing a custom callable, you can customize your copy of the model, e.g. 
by wrapping certain layers of interest (you might want to replace all `LSTM` instances with equivalent `Bidirectional(LSTM(...))` instances, for example). Defaults to `None`. |\n| `call_function` | Callable with signature `fn(layer, *args, **kwargs)` to be used to call each cloned layer and a set of inputs. It takes the layer instance, the call arguments and keyword arguments, and returns the call outputs. If unspecified, this callable defaults to the regular `__call__()` method: `def fn(layer, *args, **kwargs): return layer(*args, **kwargs)`. By passing a custom callable, you can insert new layers before or after a given layer. Note: this argument can only be used with Functional models. |\n| `recursive` | Boolean. Whether to recursively clone any Sequential or Functional models encountered in the original Sequential/Functional model. If `False`, then inner models are cloned by calling `clone_function()`. If `True`, then inner models are cloned by calling `clone_model()` with the same `clone_function`, `call_function`, and `recursive` arguments. Note that in this case, `call_function` will not be propagated to any Sequential model (since it is not applicable to Sequential models). |\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Returns ------- ||\n|---|---|\n| An instance of `Model` reproducing the behavior of the original model, on top of new inputs tensors, using newly instantiated weights. The cloned model may behave differently from the original model if a custom `clone_function` or `call_function` modifies a layer or layer call. ||\n\n\u003cbr /\u003e\n\n#### Example:\n\n # Create a test Sequential model.\n model = keras.Sequential([\n keras.layers.Input(shape=(728,)),\n keras.layers.Dense(32, activation='relu'),\n keras.layers.Dense(1, activation='sigmoid'),\n ])\n # Create a copy of the test model (with freshly initialized weights).\n new_model = clone_model(model)\n\nUsing a `clone_function` to make a model deterministic by setting the\nrandom seed everywhere: \n\n def clone_function(layer):\n config = layer.get_config()\n if \"seed\" in config:\n config[\"seed\"] = 1337\n return layer.__class__.from_config(config)\n\n new_model = clone_model(model)\n\nUsing a `call_function` to add a `Dropout` layer after each `Dense` layer\n(without recreating new layers): \n\n def call_function(layer, *args, **kwargs):\n out = layer(*args, **kwargs)\n if isinstance(layer, keras.layers.Dense):\n out = keras.layers.Dropout(0.5)(out)\n return out\n\n new_model = clone_model(\n model,\n clone_function=lambda x: x, # Reuse the same layers.\n call_function=call_function,\n )\n\nNote that subclassed models cannot be cloned by default,\nsince their internal layer structure is not known.\nTo achieve equivalent functionality\nas `clone_model` in the case of a subclassed model, simply make sure\nthat the model class implements `get_config()`\n(and optionally `from_config()`), and call: \n\n new_model = model.__class__.from_config(model.get_config())\n\nIn the case of a subclassed model, you cannot using a custom\n`clone_function`."]]