Coverage for wsimod\extensions.py: 29%

24 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-30 14:52 +0000

1"""This module contains the utilities to extend WSIMOD with new features.

2 

3The `register_node_patch` decorator is used to register a function that will be used 

4instead of a method or attribute of a node. The `apply_patches` function applies all 

5registered patches to a model. 

6 

7Example of patching a method: 

8 

9`empty_distributed` will be called instead of `pull_distributed` of "my_node": 

10 

11 >>> from wsimod.extensions import register_node_patch, apply_patches 

12 >>> @register_node_patch("my_node", "pull_distributed") 

13 ... def empty_distributed(self, vqip):

14 ...     return {}

15 

16Attributes, methods of the node, and sub-attributes can be patched. Also, an item of a 

17list or a dictionary can be patched if the item argument is provided. 

18 

19Example of patching an attribute: 

20 

21`10` will be assigned to `t`: 

22 

23 >>> @register_node_patch("my_node", "t", is_attr=True) 

24 ... def patch_t(node):

25 ...     return 10

26 

27Example of patching an attribute item: 

28 

29`patch_default_pull_set_handler` will be assigned to 

30`pull_set_handler["default"]`: 

31 

32 >>> @register_node_patch("my_node", "pull_set_handler", item="default") 

33 ... def patch_default_pull_set_handler(self, vqip):

34 ...     return {}

35 

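36If patching a method of an attribute, the target should include the attribute name and the method name, all separated by

37periods, eg. `attribute_name.method_name`, and `is_attr` should be left as `False`; the patched method is then bound to

38that attribute object rather than to the node. Use `is_attr=True` only to replace an attribute with a value computed from the node.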

39 

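40It should be noted that the patched function should have the same signature as the

41original method or attribute, and the return type should be the same as well, otherwise

42there will be a runtime error. In particular, the first argument of a patched method is

43the object that owns it, typically named `self`: the node itself for methods of the node and for patched items, or the attribute object when the target is `attribute_name.method_name`. For `is_attr=True` patches, the only argument passed is the node.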

44 

45The overridden method or attribute can be accessed within the patched function using the 

46`_patched_{method_name}` attribute of the object, eg. `self._patched_pull_distributed`.  

47The exception to this is when patching an item, in which case the original item is not

48available to be used within the overriding function. 

49 

50Finally, `apply_patches` is called within the `Model.load` method and will apply all

51patches in the order they were registered. This means that users need to be careful with 

52the order of the patches in their extensions files, as they may have interdependencies. 

53 

54TODO: Update documentation on extensions files. 

55""" 

56 

57import warnings 

58from typing import Callable, Hashable 

59 

60from .orchestration.model import Model 

61 

# Registry of pending patches, keyed by (node_name, target, item, is_attr) and
# mapped to the decorated function. Dict insertion order is the order in which
# `apply_patches` iterates over (and therefore applies) the patches.
extensions_registry: dict[tuple[str, str, Hashable, bool], Callable] = {}


def register_node_patch(
    node_name: str, target: str, item: Hashable = None, is_attr: bool = False
) -> Callable[[Callable], Callable]:
    """Register a function to patch a node method or any of its attributes.

    The patch is only recorded here; it takes effect when `apply_patches` is run
    on a model. Registering again for the same (node_name, target, item, is_attr)
    combination emits a warning. Once the decorator is applied, it replaces the
    earlier function. The entry keeps its original position in the registry's
    insertion order.

    Args:
        node_name (str): The name of the node to patch.
        target (str): The target of the object to patch in the form of a string
            with the attribute, sub-attribute, etc. and finally the method (or
            attribute) to replace, separated by periods, eg. `make_discharge` or
            `sewer_tank.pull_storage_exact`.
        item (Hashable): Typically a string or an integer indicating the item to
            replace in the selected attribute, which should be a list or a
            dictionary. Defaults to None, meaning the target itself is replaced.
        is_attr (bool): If True, the decorated function will be called when
            applying the patch and the result assigned to the target, instead of
            assigning the function itself. In this case, the only argument passed
            to the function is the node object.

    Returns:
        Callable: A decorator that stores the decorated function in
        `extensions_registry` under this target and returns the function
        unchanged.
    """
    target_id = (node_name, target, item, is_attr)
    if target_id in extensions_registry:
        # stacklevel=2 attributes the warning to the caller (typically the user's
        # extensions file) rather than to this module.
        warnings.warn(f"Patch for {target} already registered.", stacklevel=2)

    def decorator(func: Callable) -> Callable:
        extensions_registry[target_id] = func
        return func

    return decorator

92 

93 

def apply_patches(model: Model) -> None:
    """Apply every registered patch to the nodes of `model`.

    Patches are processed in the insertion order of `extensions_registry`. For
    each one, the dotted `target` is resolved starting from the node named
    `node_name` in `model.nodes`. All but the last component are followed with
    `getattr`; the last component names the member to replace on the object
    reached (the "owner").

    - With an `item`: the member is looked up on the owner and its entry `item`
      is assigned. The new value is `func(node)` when `is_attr` is True;
      otherwise it is `func` bound as a method of the node. The previous entry
      is not kept.
    - Without an `item`: the current member is first saved as
      `_patched_{member}` on the owner. The member is then replaced by
      `func(node)` when `is_attr` is True; otherwise it is replaced by `func`
      bound as a method of the owner (the node itself, or the sub-attribute for
      dotted targets).

    TODO: Validate signature of the patched methods and type of patched attributes.

    Args:
        model (Model): The model to apply the patches to.
    """
    for patch_id, patch_func in extensions_registry.items():
        node_name, target, item, is_attr = patch_id
        *path, member = target.split(".")

        node = model.nodes[node_name]

        # Walk the dotted path down to the object that owns `member`.
        owner = node
        for step in path:
            owner = getattr(owner, step)

        if item is None:
            # Save the original before computing the replacement, so an `is_attr`
            # function can already read `_patched_{member}` from the owner.
            setattr(owner, f"_patched_{member}", getattr(owner, member))
            if is_attr:
                replacement = patch_func(node)
            else:
                replacement = patch_func.__get__(owner, owner.__class__)
            setattr(owner, member, replacement)
        else:
            container = getattr(owner, member)
            if is_attr:
                container[item] = patch_func(node)
            else:
                container[item] = patch_func.__get__(node, node.__class__)