Skip to content

Push-down method

Introduction

This module implements the push-down method refactoring.

Pre-conditions:

Todo: Add pre-conditions

Post-conditions:

Todo: Add post-conditions

DeleteSourceListener (JavaParserLabeledListener)

Source code in codart\refactorings\pushdown_method.py
class DeleteSourceListener(JavaParserLabeledListener):
    """
    Removes the declaration of ``source_method`` from a parsed Java file.

    The removal is only recorded in ``token_stream_rewriter``; the caller is
    responsible for writing ``getDefaultText()`` back to disk.

    Note: the match is on the method's simple name alone, so every overload and
    every same-named method of any class in the file is removed.
    """

    def __init__(self, common_token_stream: CommonTokenStream, source_method: str):
        """
        Args:
            common_token_stream: Token stream of the parsed Java file.
            source_method: Simple name of the method whose declaration is removed.
        """
        self.source_method = source_method
        self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)

    def enterMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        if ctx.IDENTIFIER().getText() != self.source_method:
            return
        # Two levels up from the method declaration, presumably the enclosing
        # class-body declaration (so modifiers/annotations go too) — confirm
        # against the grammar.
        declaration = ctx.parentCtx.parentCtx
        self.token_stream_rewriter.replaceRange(
            from_idx=declaration.start.tokenIndex,
            to_idx=declaration.stop.tokenIndex,
            text=""
        )

__init__(self, common_token_stream, source_method) special

Source code in codart\refactorings\pushdown_method.py
def __init__(self, common_token_stream: CommonTokenStream, source_method: str):
    """
    Remember the method to delete and build a rewriter over the token stream.

    Args:
        common_token_stream: Token stream of the parsed Java file.
        source_method: Simple name of the method whose declaration is removed.
    """
    self.source_method = source_method
    self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)

PropagationListener (JavaParserLabeledListener)

Source code in codart\refactorings\pushdown_method.py
class PropagationListener(JavaParserLabeledListener):
    """
    Base listener for updating call sites of a pushed-down method in a caller file.

    It keeps ``is_safe`` true while the walker is directly inside the caller
    class ``class_name``. When the walk ends, it adds
    ``import <target_package>.<child_class>;`` after the last package/import
    declaration, unless one of these holds:

    * the file already declares ``target_package``;
    * the file already imports the child class (single-type or wildcard import).

    Subclasses implement the actual call-site rewriting.
    """

    def __init__(self, common_token_stream: CommonTokenStream, source_class: str, child_class: str, class_name: str,
                 method_name: str, ref_line: int, target_package: str):
        """
        Args:
            common_token_stream: Token stream of the caller file.
            source_class: Name of the class the method is pushed down from.
            child_class: Name of the class that receives the method.
            class_name: Name of the caller class whose body may be rewritten.
            method_name: Name of the pushed-down method.
            ref_line: Line of the call reference to rewrite (used by subclasses).
            target_package: Package of ``child_class``, used for the added import.
        """

        self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)
        self.source_class = source_class
        self.child_class = child_class
        self.class_name = class_name
        self.method_name = method_name
        self.ref_line = ref_line
        self.target_package = target_package

        self.start = None
        self.stop = None
        self.is_safe = False
        self.need_cast = False
        self.variable = None
        self.detected_class = False
        self.detected_package = False
        self.import_end = None

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        self.is_safe = ctx.IDENTIFIER().getText() == self.class_name

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        # Restore the state of the lexically enclosing class declaration instead
        # of toggling the flag. Toggling turned a non-caller class "safe" after
        # one of its nested classes was exited.
        enclosing = ctx.parentCtx
        while enclosing is not None and not isinstance(enclosing, JavaParserLabeled.ClassDeclarationContext):
            enclosing = enclosing.parentCtx
        self.is_safe = enclosing is not None and enclosing.IDENTIFIER().getText() == self.class_name

    def enterPackageDeclaration(self, ctx: JavaParserLabeled.PackageDeclarationContext):
        # Compare the package name exactly. A substring test also matched
        # sub-packages and prefixes, wrongly skipping the needed import.
        if ctx.qualifiedName().getText() == self.target_package:
            self.detected_package = True
        self.import_end = ctx.stop

    def enterImportDeclaration(self, ctx: JavaParserLabeled.ImportDeclarationContext):
        # A static import brings in members, not the type itself, so it never
        # satisfies the requirement.
        is_static = ctx.getChild(1).getText() == "static"
        if not is_static:
            imported = ctx.qualifiedName().getText()
            is_wildcard = ctx.getText().endswith(".*;")
            if imported == f"{self.target_package}.{self.child_class}" or (
                    is_wildcard and imported == self.target_package):
                self.detected_package = True
        self.import_end = ctx.stop

    def exitCompilationUnit(self, ctx: JavaParserLabeled.CompilationUnitContext):
        # An empty target package means the default package, which cannot be
        # imported from, so no import is added in that case.
        if not self.detected_package and self.target_package and self.import_end is not None:
            self.token_stream_rewriter.insertAfterToken(
                token=self.import_end,
                text=f"\nimport {self.target_package}.{self.child_class};\n",
                program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME
            )

__init__(self, common_token_stream, source_class, child_class, class_name, method_name, ref_line, target_package) special

Source code in codart\refactorings\pushdown_method.py
def __init__(self, common_token_stream: CommonTokenStream, source_class: str, child_class: str, class_name: str,
             method_name: str, ref_line: int, target_package: str):
    """
    Store the propagation parameters and reset the walk state.

    Args:
        common_token_stream: Token stream of the caller file.
        source_class: Name of the class the method is pushed down from.
        child_class: Name of the class that receives the method.
        class_name: Name of the caller class whose body may be rewritten.
        method_name: Name of the pushed-down method.
        ref_line: Line of the call reference to rewrite.
        target_package: Package of ``child_class``.
    """
    self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)

    # Refactoring parameters.
    self.source_class, self.child_class, self.class_name = source_class, child_class, class_name
    self.method_name, self.ref_line, self.target_package = method_name, ref_line, target_package

    # Per-walk state: remembered tokens, flags and the import insertion point.
    self.start = self.stop = None
    self.is_safe = self.need_cast = False
    self.variable = None
    self.detected_class = self.detected_package = False
    self.import_end = None

PropagationNonStaticListener (PropagationListener)

Source code in codart\refactorings\pushdown_method.py
class PropagationNonStaticListener(PropagationListener):
    """
    Rewrites call sites of a pushed-down instance (non-static) method.

    All rewrites are recorded in ``token_stream_rewriter`` and only happen while
    ``is_safe`` is true (inside the caller class ``class_name``):

    * An object creation naming ``source_class`` sets ``detected_class``. If a
      call to ``method_name`` follows and starts on ``ref_line``, the created
      type name is replaced with ``child_class``.
    * If a variable declarator completes while ``detected_class`` is still set,
      its name is stored in ``variable``. On ``ref_line``, the first grandchild
      of an ``Expression1`` node is then replaced with
      ``((child_class) variable)``. Presumably that grandchild is the receiver of
      a member access — confirm against the grammar.
    """

    def exitCreatedName0(self, ctx: JavaParserLabeled.CreatedName0Context):
        # ``new <source_class>(...)``: remember the type-name tokens so they can
        # be replaced once the matching method call is seen.
        if ctx.IDENTIFIER(0).getText() == self.source_class and self.is_safe:
            self.detected_class = True
            self.start = ctx.start
            self.stop = ctx.stop

    def enterMethodCall0(self, ctx: JavaParserLabeled.MethodCall0Context):
        if ctx.IDENTIFIER().getText() == self.method_name and self.is_safe and self.detected_class:
            # Change Name
            if ctx.start.line == self.ref_line:
                self.token_stream_rewriter.replaceRange(
                    from_idx=self.start.tokenIndex,
                    to_idx=self.stop.tokenIndex,
                    text=self.child_class
                )
            # The pending creation is consumed whether or not it was on ref_line.
            self.detected_class = False

    def exitVariableDeclarator(self, ctx: JavaParserLabeled.VariableDeclaratorContext):
        # The creation was assigned to a variable rather than called directly;
        # remember the variable name for a later cast.
        if self.detected_class and self.is_safe:
            self.variable = ctx.variableDeclaratorId().IDENTIFIER().getText()
            self.detected_class = False

    def enterExpression21(self, ctx: JavaParserLabeled.Expression21Context):
        if ctx.start.line == self.ref_line and self.is_safe:
            self.need_cast = True

    def exitExpression21(self, ctx: JavaParserLabeled.Expression21Context):
        if ctx.start.line == self.ref_line and self.is_safe:
            self.need_cast = False

    def enterExpression1(self, ctx: JavaParserLabeled.Expression1Context):
        # Reuses the Expression21 logic; it only reads ctx.start.line.
        self.enterExpression21(ctx)

    def exitExpression1(self, ctx: JavaParserLabeled.Expression1Context):
        if self.is_safe and self.need_cast and self.variable is not None:
            # Type casting
            child = ctx.getChild(0).getChild(0)
            self.token_stream_rewriter.replaceRange(
                from_idx=child.start.tokenIndex,
                to_idx=child.stop.tokenIndex,
                text=f"(({self.child_class}) {self.variable})"
            )
            self.need_cast = False

PropagationStaticListener (PropagationListener)

Source code in codart\refactorings\pushdown_method.py
class PropagationStaticListener(PropagationListener):
    """
    Rewrites call sites of a pushed-down static method.

    Inside the caller class (``is_safe``), a qualified call
    ``<source_class>.<method_name>(...)`` is rewritten to
    ``<child_class>.<method_name>(...)`` via ``token_stream_rewriter``.
    """

    def __init__(self, *args, **kwargs):
        """
        Accepts the same arguments as ``PropagationListener``.
        """

        super(PropagationStaticListener, self).__init__(*args, **kwargs)
        # True while inside a call to ``method_name`` whose qualifier was rewritten.
        self.detected_method = False

    def enterPrimary4(self, ctx: JavaParserLabeled.Primary4Context):
        # Only remember identifiers that name the source class. Previously any
        # identifier (e.g. an argument ``x`` in ``Source.method(x)``) overwrote
        # the remembered tokens, and that wrong token was then replaced.
        if self.is_safe and ctx.getText() == self.source_class:
            self.start = ctx.start
            self.stop = ctx.stop
            self.detected_class = True

    def enterMethodCall0(self, ctx: JavaParserLabeled.MethodCall0Context):
        if not self.is_safe:
            return
        if ctx.IDENTIFIER().getText() == self.method_name and self.detected_class:
            # The qualifier is the last source-class primary seen before this
            # call. Rewrite it now, before the argument list is walked.
            self.token_stream_rewriter.replaceRange(
                from_idx=self.start.tokenIndex,
                to_idx=self.stop.tokenIndex,
                text=f"{self.child_class}"
            )
            self.detected_method = True
        # Any call consumes the pending qualifier, so it cannot leak into a later,
        # unrelated call.
        self.detected_class = False

    def exitMethodCall0(self, ctx: JavaParserLabeled.MethodCall0Context):
        if ctx.IDENTIFIER().getText() == self.method_name:
            self.detected_method = False

PushDownMethodRefactoringListener (JavaParserLabeledListener)

Source code in codart\refactorings\pushdown_method.py
class PushDownMethodRefactoringListener(JavaParserLabeledListener):
    """
    Pastes the pushed-down method into the child class.

    ``source_method_text`` is inserted just before the closing brace of the body
    of the class declaration named ``source_class``. Despite its name,
    ``source_class`` is the receiving child class. Bodies of anonymous or
    nested classes are left untouched. The edit is only recorded in
    ``token_stream_rewriter``.
    """

    def __init__(self, common_token_stream: CommonTokenStream, source_class: str, source_method_text: str):
        """
        Args:
            common_token_stream: Token stream of the child class file.
            source_class: Name of the class that receives the method.
            source_method_text: Full source text of the method declaration.
        """

        self.source_method_text = source_method_text
        self.source_class = source_class
        self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)
        self.is_safe = False

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        self.is_safe = ctx.IDENTIFIER().getText() == self.source_class

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        # Restore the state of the lexically enclosing class declaration instead
        # of toggling, which marked non-target classes as target after a nested exit.
        enclosing = ctx.parentCtx
        while enclosing is not None and not isinstance(enclosing, JavaParserLabeled.ClassDeclarationContext):
            enclosing = enclosing.parentCtx
        self.is_safe = enclosing is not None and enclosing.IDENTIFIER().getText() == self.source_class

    def enterClassBody(self, ctx: JavaParserLabeled.ClassBodyContext):
        # Only the body owned directly by the target class declaration. Anonymous
        # class bodies (whose parent is a class creator) previously received a
        # copy of the method as well.
        owner = ctx.parentCtx
        if not isinstance(owner, JavaParserLabeled.ClassDeclarationContext):
            return
        if owner.IDENTIFIER().getText() != self.source_class:
            return
        self.token_stream_rewriter.insertBefore(
            index=ctx.stop.tokenIndex,
            text=self.source_method_text + "\n",
            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME
        )

__init__(self, common_token_stream, source_class, source_method_text) special

Source code in codart\refactorings\pushdown_method.py
def __init__(self, common_token_stream: CommonTokenStream, source_class: str, source_method_text: str):
    """
    Store the paste parameters and build a rewriter over the token stream.

    Args:
        common_token_stream: Token stream of the child class file.
        source_class: Name of the class that receives the method.
        source_method_text: Full source text of the method declaration.
    """
    self.source_class, self.source_method_text = source_class, source_method_text
    self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)
    self.is_safe = False

main(udb_path, source_package, source_class, method_name, target_classes, *args, **kwargs)

The main API for the push-down method refactoring operation

Source code in codart\refactorings\pushdown_method.py
def _rewrite_java_file(file_path, make_listener):
    """
    Parse a Java file, walk it with a listener and overwrite it with the result.

    ``make_listener`` receives the file's ``CommonTokenStream`` and must return a
    listener exposing ``token_stream_rewriter``. Returns that listener.
    """
    stream = FileStream(file_path, encoding='utf8', errors='ignore')
    token_stream = CommonTokenStream(JavaLexer(stream))
    parse_tree = JavaParserLabeled(token_stream).compilationUnit()
    listener = make_listener(token_stream)
    ParseTreeWalker().walk(t=parse_tree, listener=listener)
    with open(file_path, mode='w', encoding='utf8', errors='ignore', newline='') as f:
        f.write(listener.token_stream_rewriter.getDefaultText())
    return listener


def main(udb_path, source_package, source_class, method_name, target_classes: list, *args, **kwargs):
    """

    The main API for the push-down method refactoring operation

    Moves ``<source_package>.<source_class>.<method_name>`` into the single class
    in ``target_classes`` that extends ``source_class``, then rewrites the
    method's call sites.

    Args:
        udb_path: Path of the Understand database to analyse.
        source_package: Package of the parent class. It is also used as the
            package of the child class for added imports.
        source_class: Name of the parent class that currently declares the method.
        method_name: Name of the method to push down.
        target_classes: Names of the child classes; exactly one is supported.

    Returns:
        bool: True on success, False if a pre-condition fails (the reason is logged).

    """
    target_package = source_package
    source_method = method_name

    main_file = None
    source_method_entity = None
    is_static = False
    propagation_files = []
    propagation_classes = []
    propagation_lines = []
    children_classes = []
    children_files = []

    # Initialize with understand; the database is always closed, even on errors.
    db = und.open(udb_path)
    try:
        methods = db.ents("Java Method")
        full_name = source_package + "." + source_class + "." + source_method
        for mth in methods:
            if mth.longname() != full_name:
                continue
            source_method_entity = mth
            for child_ref in mth.parent().refs("Extendby"):
                child_ent = child_ref.ent()
                if child_ent.simplename() in target_classes:
                    children_classes.append(child_ent.simplename())
                    children_files.append(child_ent.parent().longname())
            is_static = mth.kind().check("static")
            main_file = mth.parent().parent().longname()
            for ref in mth.refs("Callby"):
                propagation_files.append(ref.ent().parent().parent().longname())
                propagation_classes.append(ref.ent().parent().simplename())
                propagation_lines.append(ref.line())

        # Check pre-conditions
        if len(target_classes) != 1:
            logger.error("len(target_classes) is not 1.")
            return False

        if source_method_entity is None:
            logger.error(f"Cannot find method to push down: {full_name}")
            return False

        if len(children_classes) != 1:
            logger.error("len(children_classes) is not 1.")
            return False

        if len(children_files) != 1:
            logger.error("len(children_files) is not 1.")
            return False

        for mth in methods:
            if (mth.simplename() == source_method
                    and mth.parent().simplename() in target_classes
                    and mth.type() == source_method_entity.type()
                    and mth.kind() == source_method_entity.kind()
                    and mth.parameters() == source_method_entity.parameters()):
                logger.error("Duplicated method")
                return False

        # Everything the method uses must stay reachable from the child class.
        for ref in source_method_entity.refs("use, call"):
            if not ref.ent().kind().check("public"):
                logger.error("Has internal dependencies.")
                return False

        method_text = source_method_entity.contents()
    finally:
        db.close()

    # Delete source method
    _rewrite_java_file(
        main_file,
        lambda ts: DeleteSourceListener(common_token_stream=ts, source_method=source_method)
    )

    # Do the push down
    for child_file, child_class in zip(children_files, children_classes):
        _rewrite_java_file(
            child_file,
            lambda ts, cls=child_class: PushDownMethodRefactoringListener(
                common_token_stream=ts, source_class=cls, source_method_text=method_text
            )
        )

    # Propagation
    listener_class = PropagationStaticListener if is_static else PropagationNonStaticListener
    for file, caller_class, line in zip(propagation_files, propagation_classes, propagation_lines):
        _rewrite_java_file(
            file,
            lambda ts, caller=caller_class, ref_line=line: listener_class(
                common_token_stream=ts, source_class=source_class,
                child_class=children_classes[0], class_name=caller,
                method_name=source_method,
                ref_line=ref_line, target_package=target_package
            )
        )

    return True

Push-down method 2

Introduction

This module implements a lightweight version of the push-down method refactoring described in pushdown_method.py.

Pre-conditions:

Todo: Add pre-conditions

Post-conditions:

Todo: Add post-conditions

CutMethodListener (JavaParserLabeledListener)

Removes the method declaration from the parent class.

Source code in codart\refactorings\pushdown_method2.py
class CutMethodListener(JavaParserLabeledListener):
    """

    Removes the method declaration from the parent class.

    It deletes, through ``rewriter``, every member of ``source_class`` that is a
    method named ``method_name``. Only direct members are affected; members of
    nested or anonymous classes are left alone. The deleted text is collected
    in ``method_content``; overloads are joined with ``"\\n\\t"``. Every import
    declaration of the file is collected in ``import_statements``.

    """
    def __init__(self, source_class, method_name, rewriter: TokenStreamRewriter):
        """

        Args:

            source_class: (str) Parent's class name.

            method_name: (str) Method's name.

            rewriter (TokenStreamRewriter): ANTLR's token stream rewriter.

        """

        self.source_class = source_class
        self.method_name = method_name
        self.rewriter = rewriter
        self.method_content = ""
        self.import_statements = ""

        self.detected_method = False
        self.is_source_class = False

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        # A nested class is never the source class itself, even inside it.
        self.is_source_class = ctx.IDENTIFIER().getText() == self.source_class

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        # Restore the flag for the lexically enclosing class declaration (if any).
        enclosing = ctx.parentCtx
        while enclosing is not None and not isinstance(enclosing, JavaParserLabeled.ClassDeclarationContext):
            enclosing = enclosing.parentCtx
        self.is_source_class = enclosing is not None and enclosing.IDENTIFIER().getText() == self.source_class

    def enterImportDeclaration(self, ctx: JavaParserLabeled.ImportDeclarationContext):
        statement = self.rewriter.getText(
            program_name=self.rewriter.DEFAULT_PROGRAM_NAME,
            start=ctx.start.tokenIndex,
            stop=ctx.stop.tokenIndex
        )
        self.import_statements += statement + "\n"

    def exitMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        if self.is_source_class and ctx.IDENTIFIER().getText() == self.method_name:
            self.detected_method = True

    def exitClassBodyDeclaration2(self, ctx: JavaParserLabeled.ClassBodyDeclaration2Context):
        if not self.detected_method:
            return
        self.detected_method = False

        # Only cut members whose class body belongs to the source class
        # declaration itself. Members of anonymous classes have another owner.
        class_body = ctx.parentCtx
        owner = class_body.parentCtx if class_body is not None else None
        if not (isinstance(owner, JavaParserLabeled.ClassDeclarationContext)
                and owner.IDENTIFIER().getText() == self.source_class):
            return

        content = self.rewriter.getText(
            program_name=self.rewriter.DEFAULT_PROGRAM_NAME,
            start=ctx.start.tokenIndex,
            stop=ctx.stop.tokenIndex
        )
        # Keep every cut declaration. Overwriting lost all overloads but the last,
        # even though all of them were deleted.
        self.method_content = f"{self.method_content}\n\t{content}" if self.method_content else content
        self.rewriter.delete(
            program_name=self.rewriter.DEFAULT_PROGRAM_NAME,
            from_idx=ctx.start.tokenIndex,
            to_idx=ctx.stop.tokenIndex
        )

__init__(self, source_class, method_name, rewriter) special

Parameters:

Name Type Description Default
source_class

(str) Parent's class name.

required
method_name

(str) Method's name.

required
rewriter TokenStreamRewriter

ANTLR's token stream rewriter.

required

Returns:

Type Description
field_content (CutMethodListener)

The full string of method declaration

Source code in codart\refactorings\pushdown_method2.py
def __init__(self, source_class, method_name, rewriter: TokenStreamRewriter):
    """

    Store the cut parameters and reset the per-walk state.

    Args:

        source_class (str): Name of the parent class the method is removed from.

        method_name (str): Name of the method to remove.

        rewriter (TokenStreamRewriter): ANTLR token stream rewriter that records the deletion.

    """
    self.source_class, self.method_name, self.rewriter = source_class, method_name, rewriter

    # Filled in while the tree is walked.
    self.method_content = self.import_statements = ""
    self.detected_method = self.is_source_class = False

PasteMethodListener (JavaParserLabeledListener)

Inserts method declaration to children classes.

Source code in codart\refactorings\pushdown_method2.py
class PasteMethodListener(JavaParserLabeledListener):
    """

    Inserts method declaration to children classes.

    ``method_content`` is inserted just before the closing brace of the body of
    the class declaration named ``source_class``; nested and anonymous class
    bodies are skipped. ``import_statements`` are inserted after the package
    declaration. If the file has no package declaration, they go at the top of
    the compilation unit instead.

    """
    def __init__(self, source_class, method_content, import_statements, rewriter: TokenStreamRewriter):
        """

        Args:

            source_class (str): Child class name.

            method_content (str): Full string of the method declaration.

            import_statements (str): Import declarations to add to the child file.

            rewriter (TokenStreamRewriter): ANTLR's token stream rewriter.

        """

        self.source_class = source_class
        self.rewriter = rewriter
        self.method_content = method_content
        self.import_statements = import_statements
        self.is_source_class = False
        # Set once the imports have been placed after a package declaration.
        self._imports_inserted = False

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        # A nested class is never the target class itself, even inside it.
        self.is_source_class = ctx.IDENTIFIER().getText() == self.source_class

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        # Restore the flag for the lexically enclosing class declaration (if any).
        enclosing = ctx.parentCtx
        while enclosing is not None and not isinstance(enclosing, JavaParserLabeled.ClassDeclarationContext):
            enclosing = enclosing.parentCtx
        self.is_source_class = enclosing is not None and enclosing.IDENTIFIER().getText() == self.source_class

    def exitPackageDeclaration(self, ctx: JavaParserLabeled.PackageDeclarationContext):
        self.rewriter.insertAfter(
            program_name=self.rewriter.DEFAULT_PROGRAM_NAME,
            index=ctx.stop.tokenIndex,
            text="\n" + self.import_statements
        )
        self._imports_inserted = True

    def exitCompilationUnit(self, ctx: JavaParserLabeled.CompilationUnitContext):
        # Default-package files have no package declaration; previously their
        # imports were silently dropped.
        if not self._imports_inserted and self.import_statements:
            self.rewriter.insertBefore(
                program_name=self.rewriter.DEFAULT_PROGRAM_NAME,
                index=ctx.start.tokenIndex,
                text=self.import_statements + "\n"
            )
            self._imports_inserted = True

    def enterClassBody(self, ctx: JavaParserLabeled.ClassBodyContext):
        # Only the body owned directly by the target class declaration. Nested
        # and anonymous class bodies previously received a copy as well.
        owner = ctx.parentCtx
        if not isinstance(owner, JavaParserLabeled.ClassDeclarationContext):
            return
        if owner.IDENTIFIER().getText() != self.source_class:
            return
        self.rewriter.insertBefore(
            program_name=self.rewriter.DEFAULT_PROGRAM_NAME,
            index=ctx.stop.tokenIndex,
            text="\n\t" + self.method_content + "\n"
        )

__init__(self, source_class, method_content, import_statements, rewriter) special

Parameters:

Name Type Description Default
source_class str

Child class name.

required
method_content str

Full string of the method declaration.

required
rewriter TokenStreamRewriter

ANTLR's token stream rewriter.

required

Returns:

Type Description
object (PasteMethodListener)

An instance of PasteMethodListener class.

Source code in codart\refactorings\pushdown_method2.py
def __init__(self, source_class, method_content, import_statements, rewriter: TokenStreamRewriter):
    """

    Store what to paste and where.

    Args:

        source_class (str): Name of the child class that receives the method.

        method_content (str): Full text of the method declaration to insert.

        import_statements (str): Import declarations to add to the child file.

        rewriter (TokenStreamRewriter): ANTLR token stream rewriter that records the insertions.

    """
    self.source_class, self.rewriter = source_class, rewriter
    self.method_content, self.import_statements = method_content, import_statements
    self.is_source_class = False

main(udb_path, source_package, source_class, method_name, target_classes, *args, **kwargs)

The main API for the push-down method refactoring (version 2)

Source code in codart\refactorings\pushdown_method2.py
def main(udb_path, source_package, source_class, method_name, target_classes: list, *args, **kwargs):
    """

    The main API for the push-down method refactoring (version 2)

    Cuts ``<source_package>.<source_class>.<method_name>`` from the parent class
    and pastes it, together with the parent file's imports, into every child
    class.

    Args:
        udb_path: Path of the Understand database to analyse.
        source_package: Package of the parent class.
        source_class: Name of the parent class that declares the method.
        method_name: Name of the method to push down.
        target_classes: Names of the child classes. Every direct subclass must
            be listed.

    Returns:
        bool: True when the method was pushed down, False when a pre-condition
        failed (the reason is logged).

    """

    db = und.open(udb_path)
    try:
        source_class_ents = db.lookup(f"{source_package}.{source_class}", "Class")
        source_class_ent = next(
            (ent for ent in source_class_ents if ent.simplename() == source_class), None
        )
        if source_class_ent is None:
            config.logger.error(f"Cannot find source class: {source_class}")
            return False

        method_ents = db.lookup(f"{source_package}.{source_class}.{method_name}", "Method")
        if len(method_ents) == 0:
            config.logger.error(f"Cannot find method to pushdown: {method_name}")
            return False
        method_ent = method_ents[0]

        target_class_ents = []
        for ref in source_class_ent.refs("extendBy"):
            if ref.ent().simplename() not in target_classes:
                config.logger.error("Target classes are not children classes")
                return False
            target_class_ents.append(ref.ent())

        # Without any child class the method would be cut and pasted nowhere.
        if not target_class_ents:
            config.logger.error(f"Source class has no children classes: {source_class}")
            return False

        for ref in method_ent.refs("callBy"):
            if ref.file().simplename().split(".")[0] not in target_classes:
                config.logger.error("Method has dependencies.")
                return False

        # Remove method from source class
        listener = parse_and_walk(
            file_path=source_class_ent.parent().longname(),
            listener_class=CutMethodListener,
            has_write=True,
            source_class=source_class,
            method_name=method_name,
            debug=False
        )
        if not listener.method_content:
            config.logger.error(f"Cannot find method declaration to pushdown: {method_name}")
            return False

        # Insert method in children classes
        for target_class in target_class_ents:
            parse_and_walk(
                file_path=target_class.parent().longname(),
                listener_class=PasteMethodListener,
                has_write=True,
                source_class=target_class.simplename(),
                method_content=listener.method_content,
                import_statements=listener.import_statements,
                debug=False
            )
        return True
    finally:
        db.close()