Skip to content

Extract class

Introduction

This module implements the Extract Class refactoring to fix the God Class (also called Large Class or Blob) code smell.

It extracts a set of fields and methods from a class into a new class.

Pre and Post Conditions

Pre Conditions:

Post Conditions:

Changelog

v0.2.1

  • Fix bugs that occurred when entity.parent() returned None

DependencyPreConditionListener (JavaParserLabeledListener)

Source code in codart\refactorings\extract_class.py
class DependencyPreConditionListener(JavaParserLabeledListener):
    """
    Record field/method dependencies of one class and group them.

    While the parse tree of a Java file is walked, this listener records the
    following for the class named ``class_identifier``:

    * every declared field (keys of ``field_dict``);
    * every method, numbered in declaration order (``method_name`` holds
      ``[name, number]`` pairs, so overloads stay distinct);
    * which methods reach which fields through an ``Expression1`` node that
      ends in an identifier (presumably a member access such as
      ``this.field`` -- confirm against the grammar). These are the values
      of ``field_dict``.

    When the class declaration is exited, :meth:`split_class` groups the
    fields and methods into weakly connected components of the
    field -> method dependency graph. It stores them in
    ``connected_components``.
    """

    def __init__(self, common_token_stream: CommonTokenStream = None, class_identifier: str = None):
        """
        Args:
            common_token_stream: Token stream of the parsed Java file.
            class_identifier: Name of the class whose members are analysed.

        Raises:
            TypeError: If ``common_token_stream`` is ``None``.
        """
        self.enter_class = False
        self.token_stream = common_token_stream
        self.class_identifier = class_identifier
        # Move all the tokens in the source code in a buffer, token_stream_rewriter.
        if common_token_stream is not None:
            self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)
        else:
            raise TypeError("common_token_stream is None")
        # field name -> list of [method_name, method_number] entries that access the field
        self.field_dict = {}
        # [method_name, method_number] for every method of the class, in declaration order
        self.method_name = []
        self.method_no = 0
        # Filled by split_class(): one [field names..., method names...] list per component
        self.connected_components = []

    def split_class(self):
        """
        Group the recorded fields and methods into connected components.

        Builds a directed graph with an edge ``field -> method number`` for
        every recorded access. It then appends one list per weakly connected
        component to ``self.connected_components``. Each list holds the
        component's field names followed by its method names.

        Fields that no method accesses never enter the graph, so they do not
        appear in any component.
        """
        graph = nx.DiGraph()
        for field, methods in self.field_dict.items():
            for method in methods:
                # Methods are keyed by their declaration number; the name is kept as a node attribute.
                graph.add_node(method[1], method_name=method[0])
                graph.add_edge(field, method[1])

        components = [graph.subgraph(c).copy() for c in nx.weakly_connected_components(graph)]

        for component in components:
            # Field nodes only have outgoing edges; method nodes always have at least one incoming edge.
            class_fields = [node for node in component.nodes if component.in_degree(node) == 0]
            class_methods = [component.nodes[node]["method_name"] for node in component.nodes
                             if component.in_degree(node) > 0]
            self.connected_components.append(class_fields + class_methods)

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        """Start recording when the target class is entered."""
        if ctx.IDENTIFIER().getText() != self.class_identifier:
            return
        self.enter_class = True

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        """Stop recording and build the components when the target class is exited."""
        if ctx.IDENTIFIER().getText() != self.class_identifier:
            # Leaving a nested or unrelated class must neither end the analysis
            # of the target class nor run split_class() a second time.
            return
        self.enter_class = False
        self.split_class()

    def enterFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext):
        """Register every field declared by this declaration (``int a, b;`` yields ``a`` and ``b``)."""
        if not self.enter_class:
            return
        for declarator in ctx.variableDeclarators().variableDeclarator():
            field_id = declarator.variableDeclaratorId().IDENTIFIER().getText()
            self.field_dict[field_id] = []

    def enterMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        """Number the method and make it the "current" method for field accesses."""
        if not self.enter_class:
            return
        self.method_no = self.method_no + 1
        self.method_name.append([ctx.IDENTIFIER().getText(), self.method_no])

    def exitMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        """No bookkeeping is needed when a method is left."""
        if not self.enter_class:
            return

    def exitExpression1(self, ctx: JavaParserLabeled.Expression1Context):
        """Record that the most recently entered method accesses a known field."""
        if not self.enter_class:
            return
        if self.method_no == 0:
            # Accesses before the first method (e.g. in field initializers) belong to no method.
            return
        identifier = ctx.IDENTIFIER()
        if identifier is None:
            # The selected member is not a plain identifier (presumably a method call, `this`, ...).
            return
        variable_name = identifier.getText()
        if variable_name not in self.field_dict:
            return
        current_method = self.method_name[-1]
        if current_method not in self.field_dict[variable_name]:
            self.field_dict[variable_name].append(current_method)

ExtractClassRefactoringListener (JavaParserLabeledListener)

Implements the Extract Class refactoring based on its actors.

Creates a new class and moves fields and methods from the old class to the new one.

Source code in codart\refactorings\extract_class.py
class ExtractClassRefactoringListener(JavaParserLabeledListener):
    """
    Implement the Extract Class refactoring based on its actors.

    Creates a new class and moves fields and methods from the old class to
    the new one. While the parse tree of the file that declares
    ``source_class`` is walked:

    * ``self.code`` accumulates the source of the new class ``new_class``.
      It contains the package declaration, the file's imports, the moved
      fields (declared ``public``) and the moved methods (declared
      ``public``).
    * ``self.token_stream_rewriter`` records the edits to the original file:

      * a ``public <new_class> <object_name> = new <new_class>();`` field is
        added to the source class;
      * moved fields are removed;
      * bodies of moved methods are replaced by a delegation to
        ``<object_name>``;
      * member accesses to moved fields outside moved methods get an
        ``<object_name>.`` prefix;
      * ``private`` fields that moved methods still need (per
        ``method_map``) become ``public``.
    """

    def __init__(self, common_token_stream: CommonTokenStream = None,
                 source_class: str = None, new_class: str = None,
                 moved_fields=None, moved_methods=None, method_map: dict = None
                 ):
        """
        Args:
            common_token_stream: Token stream of the parsed Java file. It is
                wrapped in a ``TokenStreamRewriter`` that collects all edits.
            source_class: Name of the class to extract from.
            new_class: Name of the class to create.
            moved_fields: Names of the fields to move (default: none).
            moved_methods: Names of the methods to move (default: none).
            method_map: Maps a moved method name to the names of the
                source-class fields it uses. The ``private`` modifier of
                those fields is replaced by ``public``.

        Raises:
            ValueError: If ``common_token_stream``, ``source_class`` or
                ``new_class`` is ``None``.
        """

        if method_map is None:
            self.method_map = {}
        else:
            self.method_map = method_map

        if moved_methods is None:
            self.moved_methods = []
        else:
            self.moved_methods = moved_methods
        if moved_fields is None:
            self.moved_fields = []
        else:
            self.moved_fields = moved_fields

        if common_token_stream is None:
            raise ValueError("common_token_stream is None")
        else:
            self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)

        if source_class is None:
            raise ValueError("source_class is None")
        else:
            self.source_class = source_class
        if new_class is None:
            raise ValueError("new_class is None")
        else:
            self.new_class = new_class

        self.is_source_class = False
        self.detected_field = None
        self.detected_method = None
        self.TAB = "\t"
        self.NEW_LINE = "\n"
        self.code = ""
        self.package_name = ""
        self.parameters = []
        # Name of the delegate field: new class name with a lower-cased first letter.
        self.object_name = self.new_class[:1].lower() + self.new_class[1:]
        self.modifiers = ""

        self.do_increase_visibility = False

        # Identifiers of the currently open class declarations, innermost last.
        self._class_stack = []

        # Iterate the normalised list: the raw `moved_methods` argument may be None.
        self.fields_to_increase_visibility = set()
        for method in self.moved_methods:
            used_fields = self.method_map.get(method)
            if used_fields:
                self.fields_to_increase_visibility.update(used_fields)

    @staticmethod
    def _source_text(ctx):
        """Return the original source text of ``ctx``, whitespace preserved and without rewrites."""
        return ctx.start.getInputStream().getText(ctx.start.start, ctx.stop.stop)

    @staticmethod
    def _declarator_name(declarator):
        """Return the identifier declared by a ``variableDeclarator`` node."""
        return declarator.variableDeclaratorId().IDENTIFIER().getText()

    def enterPackageDeclaration(self, ctx: JavaParserLabeled.PackageDeclarationContext):
        """Copy the first package declaration into the new class."""
        if ctx.qualifiedName() and not self.package_name:
            self.package_name = ctx.qualifiedName().getText()
            self.code += f"package {self.package_name};{self.NEW_LINE}"

    def enterImportDeclaration(self, ctx: JavaParserLabeled.ImportDeclarationContext):
        """Copy every import of the file into the new class."""
        i = self.token_stream_rewriter.getText(
            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
            start=ctx.start.tokenIndex,
            stop=ctx.stop.tokenIndex
        )
        self.code += f"\n{i}\n"

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        """Track the class nesting and open the new class when the source class starts."""
        class_identifier = str(ctx.children[1])
        self._class_stack.append(class_identifier)
        if class_identifier == self.source_class:
            self.is_source_class = True
            self.code += self.NEW_LINE * 2
            self.code += f"// New class({self.new_class}) generated by CodART" + self.NEW_LINE
            self.code += f"class {self.new_class}{self.NEW_LINE}" + "{" + self.NEW_LINE
        else:
            self.is_source_class = False

    def enterClassBody(self, ctx: JavaParserLabeled.ClassBodyContext):
        """Add the delegate field at the top of the source class body."""
        if not self.is_source_class:
            return
        if not isinstance(ctx.parentCtx, JavaParserLabeled.ClassDeclarationContext):
            # Anonymous class bodies (`new X() { ... }`) inside the source class get no delegate field.
            return
        self.token_stream_rewriter.insertAfterToken(
            token=ctx.start,
            text="\n\t" + f"public {self.new_class} {self.object_name} = new {self.new_class}();",
            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME
        )

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        """Close the new class and restore the state of the enclosing class, if any."""
        class_identifier = str(ctx.children[1])
        if self._class_stack:
            self._class_stack.pop()
        if class_identifier == self.source_class:
            self.code += "}"
        # Back inside the source class only if it is the innermost still-open class.
        self.is_source_class = bool(self._class_stack) and self._class_stack[-1] == self.source_class

    def exitCompilationUnit(self, ctx: JavaParserLabeled.CompilationUnitContext):
        pass

    def enterVariableDeclaratorId(self, ctx: JavaParserLabeled.VariableDeclaratorIdContext):
        """Remember the last moved field name seen inside the source class."""
        if not self.is_source_class:
            return None
        field_identifier = ctx.IDENTIFIER().getText()
        if field_identifier in self.moved_fields:
            self.detected_field = field_identifier

    def enterFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext):
        """Make a source-class field public if a moved method still uses it."""
        if not self.is_source_class or not self.fields_to_increase_visibility:
            return
        names = [self._declarator_name(d) for d in ctx.variableDeclarators().variableDeclarator()]
        if not any(name in self.fields_to_increase_visibility for name in names):
            return
        # fieldDeclaration -> memberDeclaration -> classBodyDeclaration, which owns the modifiers.
        for modifier in ctx.parentCtx.parentCtx.modifier():
            if modifier.getText() == "private":
                self.token_stream_rewriter.replaceSingleToken(
                    token=modifier.start,
                    text="public"
                )

    def exitFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext):
        """
        Move the moved declarators of this field declaration to the new class.

        Each moved declarator, including its initializer, becomes a
        ``public`` field of the new class. The remaining declarators stay in
        the source class with their original modifiers and type. When no
        declarator remains, the whole declaration is deleted.
        """
        if not self.is_source_class:
            return None
        declarators_ctx = ctx.variableDeclarators()
        moved, kept = [], []
        for declarator in declarators_ctx.variableDeclarator():
            if self._declarator_name(declarator) in self.moved_fields:
                moved.append(declarator)
            else:
                kept.append(declarator)
        self.detected_field = None
        if not moved:
            return None

        field_type = self._source_text(ctx.typeType())
        for declarator in moved:
            self.code += f"public {field_type} {self._source_text(declarator)};{self.NEW_LINE}"

        if kept:
            # Rewritten text keeps prefixes already inserted into the remaining initializers.
            kept_text = ", ".join(
                self.token_stream_rewriter.getText(
                    program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
                    start=declarator.start.tokenIndex,
                    stop=declarator.stop.tokenIndex
                )
                for declarator in kept
            )
            self.token_stream_rewriter.replaceRange(
                from_idx=declarators_ctx.start.tokenIndex,
                to_idx=declarators_ctx.stop.tokenIndex,
                text=kept_text
            )
        else:
            grand_parent_ctx = ctx.parentCtx.parentCtx
            self.token_stream_rewriter.delete(
                program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
                from_idx=grand_parent_ctx.start.tokenIndex,
                to_idx=grand_parent_ctx.stop.tokenIndex
            )

    def enterMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        """Mark a source-class method that has to be moved."""
        if not self.is_source_class:
            return None
        method_identifier = ctx.IDENTIFIER().getText()
        if method_identifier in self.moved_methods:
            self.detected_method = method_identifier

    def enterFormalParameter(self, ctx: JavaParserLabeled.FormalParameterContext):
        """Collect parameter names of the method being moved (used as delegation arguments)."""
        if self.detected_method:
            self.parameters.append(
                ctx.variableDeclaratorId().IDENTIFIER().getText()
            )

    def enterLastFormalParameter(self, ctx):
        """Collect a trailing varargs parameter of the method being moved."""
        if self.detected_method:
            self.parameters.append(
                ctx.variableDeclaratorId().IDENTIFIER().getText()
            )

    def exitMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        """Copy a moved method to the new class and turn its original body into a delegation."""
        if not self.is_source_class:
            return None
        method_identifier = ctx.IDENTIFIER().getText()
        if self.detected_method != method_identifier:
            return None

        start_index = ctx.start.tokenIndex
        stop_index = ctx.stop.tokenIndex
        method_text = self.token_stream_rewriter.getText(
            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
            start=start_index,
            stop=stop_index
        )
        self.code += self.NEW_LINE + ("public " + method_text + self.NEW_LINE)
        # delegate method body in source class
        if self.method_map.get(method_identifier):
            self.parameters.append("this")

        call = f"this.{self.object_name}.{self.detected_method}(" + ",".join(self.parameters) + ");"
        # `return <void expression>;` does not compile in Java, so void methods just call.
        is_void = ctx.typeTypeOrVoid().getText() == "void"
        statement = call if is_void else "return " + call
        self.token_stream_rewriter.replaceRange(
            from_idx=ctx.methodBody().start.tokenIndex,
            to_idx=stop_index,
            text="{" + "\n" + statement + "\n" + "}"
        )
        self.parameters = []
        self.detected_method = None

    def enterExpression1(self, ctx: JavaParserLabeled.Expression1Context):
        """Prefix member accesses to moved fields with the delegate object, outside moved methods."""
        identifier = ctx.IDENTIFIER()
        if identifier is not None:
            if identifier.getText() in self.moved_fields and self.detected_method not in self.moved_methods:
                # Found field usage!
                self.token_stream_rewriter.insertBeforeToken(
                    token=ctx.stop,
                    text=self.object_name + ".",
                    program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME
                )

__init__(self, common_token_stream=None, source_class=None, new_class=None, moved_fields=None, moved_methods=None, method_map=None) special

Source code in codart\refactorings\extract_class.py
def __init__(self, common_token_stream: CommonTokenStream = None,
             source_class: str = None, new_class: str = None,
             moved_fields=None, moved_methods=None, method_map: dict = None
             ):
    """
    Initialise an Extract Class listener.

    Args:
        common_token_stream: Token stream of the parsed Java file. It is
            wrapped in a ``TokenStreamRewriter`` that collects all edits.
        source_class: Name of the class to extract from.
        new_class: Name of the class to create.
        moved_fields: Names of the fields to move (default: none).
        moved_methods: Names of the methods to move (default: none).
        method_map: Maps a moved method name to the names of the
            source-class fields it uses. These names are collected into
            ``fields_to_increase_visibility``.

    Raises:
        ValueError: If ``common_token_stream``, ``source_class`` or
            ``new_class`` is ``None``.
    """

    if method_map is None:
        self.method_map = {}
    else:
        self.method_map = method_map

    if moved_methods is None:
        self.moved_methods = []
    else:
        self.moved_methods = moved_methods
    if moved_fields is None:
        self.moved_fields = []
    else:
        self.moved_fields = moved_fields

    if common_token_stream is None:
        raise ValueError("common_token_stream is None")
    else:
        self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)

    if source_class is None:
        raise ValueError("source_class is None")
    else:
        self.source_class = source_class
    if new_class is None:
        raise ValueError("new_class is None")
    else:
        self.new_class = new_class

    self.is_source_class = False
    self.detected_field = None
    self.detected_method = None
    self.TAB = "\t"
    self.NEW_LINE = "\n"
    self.code = ""
    self.package_name = ""
    self.parameters = []
    # Name of the delegate field: new class name with a lower-cased first letter.
    self.object_name = self.new_class[:1].lower() + self.new_class[1:]
    self.modifiers = ""

    self.do_increase_visibility = False

    # Iterate the normalised list: the raw `moved_methods` argument may be None.
    self.fields_to_increase_visibility = set()
    for method in self.moved_methods:
        used_fields = self.method_map.get(method)
        if used_fields:
            self.fields_to_increase_visibility.update(used_fields)