Call tree processing

What is a call tree?

Call trees are how siuba takes what users say they want to do and converts it into an action, such as…

  • a SQL statement
  • a set of pandas operations

Below is an example expression, alongside a graphical representation of that expression. This graphical representation is the call tree.

from siuba import _ 

_.hp + _.hp.rank()
█─+
├─█─.
│ ├─_
│ └─'hp'
└─█─'__call__'
  └─█─.
    ├─█─.
    │ ├─_
    │ └─'hp'
    └─'rank'

One thing that often catches people by surprise with call trees is that the calls for an expression like

_.hp.rank()

are not in order from left to right, but the other way around. Looking at its tree…

_.hp.rank()
█─'__call__'
└─█─.
  ├─█─.
  │ ├─_
  │ └─'hp'
  └─'rank'

It goes

  • call _.hp.rank
  • get attribute rank from _.hp
  • get attribute hp from _

I’ll call this order the entering order. It occurs when we walk down the tree depth first.
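
As a rough illustration of the entering order, here is a minimal sketch that does not use siuba itself. It assumes a call tree can be represented as nested tuples of the form (func_name, *args), which is a made-up stand-in for siuba's Call objects; label and walk are hypothetical helpers. It records the order in which operation nodes are entered and exited during a depth-first walk.

```python
# Toy stand-in for the call tree of _.hp.rank(); not siuba's Call class.
tree = ("__call__", ("__getattr__", ("__getattr__", "_", "hp"), "rank"))


def label(node):
    """Describe an operation node in words."""
    func, *args = node
    if func == "__getattr__":
        return f"get attribute {args[1]}"
    return "call"


def walk(node, entered, exited):
    """Depth-first walk that records when each operation is entered and exited."""
    if not isinstance(node, tuple):
        return  # leaves such as "_" or "hp" are not operations
    entered.append(label(node))
    for arg in node[1:]:
        walk(arg, entered, exited)
    exited.append(label(node))


entered, exited = [], []
walk(tree, entered, exited)

print(entered)  # ['call', 'get attribute rank', 'get attribute hp']
print(exited)   # ['get attribute hp', 'get attribute rank', 'call']
```

For a simple chain like this one, the exit order is the entering order reversed, which is closer to how the expression reads from left to right.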

Sometimes this order is useful, but often we’ll want to think of the operations in reverse (i.e. closer to how we read them). To allow for both situations, in siuba I often use what I’ll refer to as a tree listener. This is a concept borrowed from the ANTLR4 parser generator.

What is a tree listener?

For each node (black box) on a tree, a tree listener allows you to define some custom processing by specifying enter and exit methods.

_.hp + _.hp.rank()
█─+
├─█─.
│ ├─_
│ └─'hp'
└─█─'__call__'
  └─█─.
    ├─█─.
    │ ├─_
    │ └─'hp'
    └─'rank'

Note that nodes like + and . in the graph above are shorthand for their python method names, __add__ and __getattr__ respectively.
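
To see why these method names show up in the tree, here is a small sketch of a recording object. The Recorder class and the unwrap helper are hypothetical and far simpler than siuba's actual implementation of _. The sketch only relies on the fact that python routes ., + and () through __getattr__, __add__ and __call__, and it stores each operation as a nested (func_name, *args) tuple.

```python
class Recorder:
    """Hypothetical recording object; each operation returns a new Recorder."""

    def __init__(self, tree="_"):
        # tree is a normal instance attribute, so reading self.tree never
        # reaches __getattr__ (python only calls it when normal lookup fails)
        self.tree = tree

    def __getattr__(self, name):
        return Recorder(("__getattr__", self.tree, name))

    def __add__(self, other):
        return Recorder(("__add__", self.tree, unwrap(other)))

    def __call__(self, *args):
        return Recorder(("__call__", self.tree, *map(unwrap, args)))


def unwrap(x):
    """Replace a Recorder with the tree it holds; leave other values alone."""
    return x.tree if isinstance(x, Recorder) else x


_ = Recorder()
expr = _.hp + _.hp.rank()

print(expr.tree[0])  # __add__
print(expr.tree[1])  # ('__getattr__', '_', 'hp')
print(expr.tree[2])  # ('__call__', ('__getattr__', ('__getattr__', '_', 'hp'), 'rank'))
```

The nesting of the tuples mirrors the tree drawn above: the + node has a get attribute on the left, and a call wrapping a chain of two get attributes on the right.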

Simple exit method on a tree listener

Below is an example tree listener that strips out a __getattr__ operation from a call.

from siuba.siu import Call, BinaryOp, strip_symbolic
from siuba.siu.visitors import CallListener
from siuba import _ 

class AttrStripper(CallListener):
    def __init__(self, rm_attr):
        self.rm_attr = rm_attr
        
    def exit___getattr__(self, node):
        obj, attr_name = node.args
        if attr_name in self.rm_attr:
            return obj
        
        return node


attr_strip = AttrStripper({'hp'})

call = strip_symbolic(_.hp + _.hp.rank())

print(call)
print(attr_strip.enter(call))            
_.hp + _.hp.rank()
_ + _.rank()

Simple enter method on a tree listener

class AttrStopper(AttrStripper):
    def enter___getattr__(self, node):
        obj, attr_name = node.args
        if attr_name == "stop":
            # don't enter child nodes
            return self.exit(node)
        
        # use generic entering method on this node (and its children)
        return self.generic_enter(node)


attr_stopper = AttrStopper({'hp'})

call = strip_symbolic(_.hp + _.stop.hp + _.hp.stop)

print(call)
print(attr_stopper.enter(call))
_.hp + _.stop.hp + _.hp.stop
_ + _.stop + _.hp.stop

Use enter to “look back on the python execution timeline”

In general, it’s better to use enter when you need information about operations that python would execute earlier in time.

Some useful cases include

  • stopping further processing (by not entering child nodes)
  • modifying a child node, prior to entering (i.e. starting processing)

For example, suppose we want to treat method calls in a special way. In _.rank(), we first enter the __call__ node. Moreover, we can look up whether it is actually a method call from this node, using the rule…

  • if a call node is operating on a get attribute, then it is a method call

This is shown below…

_.dt.year
█─.
├─█─.
│ ├─_
│ └─'dt'
└─'year'
_.dt.year()
█─'__call__'
└─█─.
  ├─█─.
  │ ├─_
  │ └─'dt'
  └─'year'
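
The method call rule can be sketched with plain nested tuples standing in for call nodes. This is an assumption made so the example runs without siuba: the (func_name, *args) tuples and the is_method_call helper are hypothetical, and this is_op is a tuple-based variant of the one defined for Call objects in this section.

```python
def is_op(node, opname):
    """True if node is an operation tuple whose function name is opname."""
    return isinstance(node, tuple) and node[0] == opname


def is_method_call(node):
    """A call node is a method call when the thing it calls is a get attribute."""
    return is_op(node, "__call__") and is_op(node[1], "__getattr__")


dt_year = ("__getattr__", ("__getattr__", "_", "dt"), "year")  # _.dt.year
method = ("__call__", dt_year)                                 # _.dt.year()
plain = ("__call__", "_")                                      # _()

print(is_method_call(dt_year))  # False
print(is_method_call(method))   # True
print(is_method_call(plain))    # False
```

The check only needs the call node and its first child, which is why it can be made as soon as the call node is entered.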
# want to remove dt
# also want to treat an attribute after dt as a call
# _.a.dt.year
# 
# if we cut dt out in the exit, can't know year is attribute

# TODO: shouldn't need to import BinaryOp, but it helps with formatting
# maybe need factory function?
def is_op(node, opname):
    if isinstance(node, Call) and node.func == opname:
        return True
    
    return False

class MethodMaker(CallListener):
    def enter___getattr__(self, node):
        obj, attr_name = node.args
        
        print("Entering attribute: ", attr_name)
        
        # is to the right of another attribute call
        # e.g. _.<left_attr>.<attr_name>
        if is_op(obj, "__getattr__"):
            left_obj, left_attr = obj.args
            
            print("  Detected attribute chain: ", left_attr, attr_name)
            
            # if the left attr is dt, treat this like a method call
            # e.g. _.dt.year
            if left_attr == "dt":
                # manually enter child nodes, now that we have all the information
                # we need about them
                args, kwargs = node.map_subcalls(self.enter)
                new_obj = node.__class__("__getattr__", *args, **kwargs)
                # since it follows dt, put inside a call op
                method_call = Call("__call__", new_obj)
                return self.exit(method_call)
        
        # otherwise, use default behavior
        return self.generic_enter(node)

    def exit___getattr__(self, node):
        obj, attr_name = node.args
        
        print("Exiting attribute: ", attr_name)
        
        return node
        

method_maker = MethodMaker()

call = strip_symbolic(_.dt.year)

print("Call: ", call)
method_maker.enter(call)
Call:  _.dt.year
Entering attribute:  year
  Detected attribute chain:  dt year
Entering attribute:  dt
Exiting attribute:  dt
_.dt.year()

One limitation of this approach is that if we have an expression like…

_.dt.year()

We’ll still convert the year attribute to a call, causing us to call it twice!

method_maker.enter(strip_symbolic(_.dt.year()))
Entering attribute:  year
  Detected attribute chain:  dt year
Entering attribute:  dt
Exiting attribute:  dt
_.dt.year()()

To get around this, we can extend MethodMaker to check whether an attribute has been converted to a call.

class MethodMaker2(MethodMaker):
    def enter___call__(self, node):
        # needs to use an enter call, since need to know
        #   * what child was before entering

        obj = node.args[0]
        # don't want to return multiple calls,
        # e.g. _.dt.year() shouldn't produce _.dt.year()()
        if is_op(obj, "__getattr__"):
            args, kwargs = node.map_subcalls(self.enter)

            new_obj, *func_args = args
            
            # getattr transformed itself into a call node, but we're already
            # calling, so peel off the call node it produced...
            if is_op(new_obj, "__call__"):
                new_call = Call("__call__", new_obj.args[0], *func_args, **kwargs)
                return self.exit(new_call)
        
        return self.generic_enter(node)

    def exit___call__(self, node):
        obj = node.args[0]
        if is_op(obj, "__getattr__"):
            left_obj, left_attr = obj.args
            print("Exiting method call: ", left_attr)
        
        return node
    
        

method_maker2 = MethodMaker2()

method_maker2.enter(strip_symbolic(_.dt.year()))
Entering attribute:  year
  Detected attribute chain:  dt year
Entering attribute:  dt
Exiting attribute:  dt
Exiting method call:  year
Exiting method call:  year
_.dt.year()

Keep in mind that an enter method for an operator can do whatever an exit method for that operator could (and more!). However, there is an important caveat to keep in mind: it usually requires more code, since it also needs to enter child nodes.

We can think of the order of enter and exit operations as a big sandwich, where exit is the last step an enter “block” takes. So if the exit doesn’t handle things, the enter can.

_.hp + _.hp.rank()

enter +(_.hp, _.hp.rank())
  enter .(_, "hp")
  exit
  enter __call__(_.hp.rank)
    enter .(_.hp, "rank")
      enter .(_, "hp")
      exit
    exit
  exit
exit
    

In this sense exit is best for actions that can happen after all other processing for a node has happened.
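
The sandwich can be sketched as a toy listener. ToyListener below is a hypothetical, stripped-down stand-in that works on nested (func_name, *args) tuples; it is not siuba's CallListener, and its dispatch on method names is only assumed to resemble it. Its enter step first enters the children, rebuilds the node from the results, and only then hands the node to the exit step.

```python
class ToyListener:
    """Hypothetical mini listener over nested (func_name, *args) tuples."""

    def __init__(self):
        self.log = []

    def enter(self, node, depth=0):
        if not isinstance(node, tuple):
            return node  # leaves pass through untouched
        self.log.append("  " * depth + "enter " + node[0])
        # enter the children first...
        new_args = [self.enter(arg, depth + 1) for arg in node[1:]]
        # ...then exit is the last step of the enter block
        return self.exit((node[0], *new_args), depth)

    def exit(self, node, depth=0):
        self.log.append("  " * depth + "exit " + node[0])
        # dispatch to exit_<func_name> if the subclass defines it
        method = getattr(self, "exit_" + node[0], None)
        return method(node) if method else node


class DtStripper(ToyListener):
    def exit___getattr__(self, node):
        func, obj, attr_name = node
        # by the time we exit, the child obj has already been processed
        return obj if attr_name == "dt" else node


tree = ("__call__", ("__getattr__", ("__getattr__", "_", "dt"), "year"))  # _.dt.year()
listener = DtStripper()
result = listener.enter(tree)

print("\n".join(listener.log))
print(result)  # ('__call__', ('__getattr__', '_', 'year'))
```

Because the exit method receives a node whose children are already processed, it can stay short: it only decides what to return for the node in hand.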

Use exit for simple actions after child nodes are processed

To show where an exit is useful, let’s take the extra step of cutting out dt attributes. To do this, we can override our current __getattr__ exit method (which only prints a message right now).

class MethodMaker3(MethodMaker2):
    def exit___getattr__(self, node):
        obj, attr_name = node.args
        
        print("Exiting attribute: ", attr_name)
        
        if attr_name == "dt":
            # cut out the dt node
            return obj
        
        return node
    
method_maker3 = MethodMaker3()
# before
method_maker2.enter(call)
Entering attribute:  year
  Detected attribute chain:  dt year
Entering attribute:  dt
Exiting attribute:  dt
Exiting method call:  year
_.dt.year()
# after
method_maker3.enter(call)
Entering attribute:  year
  Detected attribute chain:  dt year
Entering attribute:  dt
Exiting attribute:  dt
Exiting method call:  year
_.year()

Finally, it’s worth asking what will happen with the following call…

call3 = strip_symbolic(_.dt.dt())
call3
_.dt.dt()
method_maker3.enter(call3)
Entering attribute:  dt
  Detected attribute chain:  dt dt
Entering attribute:  dt
Exiting attribute:  dt
Exiting method call:  dt
Exiting method call:  dt
_.dt()

Notice that there are two dt attributes, and they were both entered, but only one was exited.

Why is this? To find the answer, you need to look at the enter___getattr__ method of the original MethodMaker class. More specifically, why doesn’t it exit the node it’s processing when it creates a new Call node?