diff --git a/biokb/biokb.py b/biokb/biokb.py
index db7ea4123110fa0b98cee43e8a362b4711d5d3d2..11c65bb7e6e2e1180b02c01cf75ba63ad3b0abd1 100644
--- a/biokb/biokb.py
+++ b/biokb/biokb.py
@@ -58,7 +58,13 @@ class BioKBservice(TextMiningService):
 
         return values
 
-    def get_co_occurrences(self, entity: str, limit: int = 20) -> List[CoOccurrence]:
+    def get_co_occurrences(self, entity: str, limit: int = 20, types: List[str] = []) -> List[CoOccurrence]:
+
+        entity_types_filter = ''
+        if len(types) > 0:
+            types_str = ', '.join((f'<{t}>' for t in types))
+            entity_types_filter = f'FILTER (?e_type IN ({types_str}) )'
+
         entity = standarise_underscored_entity_code(entity)
         query = """
             select * where {
@@ -68,6 +74,8 @@ class BioKBservice(TextMiningService):
                     ?s <http://lcsb.uni.lu/biokb#containsEntity> <http://lcsb.uni.lu/biokb/entities/%ENTITY%> .
                     ?s a  <http://lcsb.uni.lu/biokb#Publication> .
                     ?s <http://lcsb.uni.lu/biokb#containsEntity> ?other_entity .
+                    ?other_entity a ?e_type .
+                    %ENTITY_TYPE_FILTER%
                 
                     OPTIONAL {?ss rdfs:subClassOf ?other_entity} .
                 
@@ -88,7 +96,8 @@ class BioKBservice(TextMiningService):
                 GROUP BY ?other_entity 
 
             } ORDER BY DESC(?count) LIMIT %LIMIT%
-        """.replace('%ENTITY%', entity).replace('%LIMIT%', str(limit))
+        """.replace('%ENTITY%', entity).replace('%LIMIT%', str(limit)).replace('%ENTITY_TYPE_FILTER%', entity_types_filter)
+        print(query)
         results = self._run_sparql_query(query)
         values = []
         values = []